diff --git a/src/Common/AssemblyLoadingUtils.cs b/src/Common/AssemblyLoadingUtils.cs new file mode 100644 index 0000000000..64477aa7e0 --- /dev/null +++ b/src/Common/AssemblyLoadingUtils.cs @@ -0,0 +1,290 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Internal.Utilities; +using System; +using System.IO; +using System.IO.Compression; +using System.Reflection; + +namespace Microsoft.ML.Runtime +{ + internal static class AssemblyLoadingUtils + { + /// + /// Make sure the given assemblies are loaded and that their loadable classes have been catalogued. + /// + public static void LoadAndRegister(IHostEnvironment env, string[] assemblies) + { + Contracts.AssertValue(env); + + if (Utils.Size(assemblies) > 0) + { + foreach (string path in assemblies) + { + Exception ex = null; + try + { + // REVIEW: Will LoadFrom ever return null? + Contracts.CheckNonEmpty(path, nameof(path)); + var assem = LoadAssembly(env, path); + if (assem != null) + continue; + } + catch (Exception e) + { + ex = e; + } + + // If it is a zip file, load it that way. + ZipArchive zip; + try + { + zip = ZipFile.OpenRead(path); + } + catch (Exception e) + { + // Couldn't load as an assembly and not a zip, so warn the user. + ex = ex ?? e; + Console.Error.WriteLine("Warning: Could not load '{0}': {1}", path, ex.Message); + continue; + } + + string dir; + try + { + dir = CreateTempDirectory(); + } + catch (Exception e) + { + throw Contracts.ExceptIO(e, "Creating temp directory for extra assembly zip extraction failed: '{0}'", path); + } + + try + { + zip.ExtractToDirectory(dir); + } + catch (Exception e) + { + throw Contracts.ExceptIO(e, "Extracting extra assembly zip failed: '{0}'", path); + } + + LoadAssembliesInDir(env, dir, false); + } + } + } + + public static IDisposable CreateAssemblyRegistrar(IHostEnvironment env, string loadAssembliesPath = null) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValueOrNull(loadAssembliesPath); + + return new AssemblyRegistrar(env, loadAssembliesPath); + } + + public static void RegisterCurrentLoadedAssemblies(IHostEnvironment env) + { + Contracts.CheckValue(env, nameof(env)); + + foreach (Assembly a in AppDomain.CurrentDomain.GetAssemblies()) + { + TryRegisterAssembly(env.ComponentCatalog, a); + } + } + + private static string CreateTempDirectory() + { + string dir = GetTempPath(); + Directory.CreateDirectory(dir); + return dir; + } + + private static string GetTempPath() + { + Guid guid = Guid.NewGuid(); + return Path.GetFullPath(Path.Combine(Path.GetTempPath(), "MLNET_" + guid.ToString())); + } + + private static readonly string[] _filePrefixesToAvoid = new string[] { + "api-ms-win", + "clr", + "coreclr", + "dbgshim", + "ext-ms-win", + "microsoft.bond.", + "microsoft.cosmos.", + "microsoft.csharp", + "microsoft.data.", + "microsoft.hpc.", + "microsoft.live.", + "microsoft.platformbuilder.", + "microsoft.visualbasic", + "microsoft.visualstudio.", + "microsoft.win32", + "microsoft.windowsapicodepack.", + "microsoft.windowsazure.", + "mscor", + "msvc", + "petzold.", + "roslyn.", + "sho", + "sni", + "sqm", + "system.", + "zlib", + }; + + private static bool ShouldSkipPath(string path) + { + string name = Path.GetFileName(path).ToLowerInvariant(); + switch (name) + { + case "cqo.dll": + case "fasttreenative.dll": + case "libiomp5md.dll": + case "libvw.dll": + case "matrixinterf.dll": + case "microsoft.ml.neuralnetworks.gpucuda.dll": + case "mklimports.dll": + case "microsoft.research.controls.decisiontrees.dll": + case "microsoft.ml.neuralnetworks.sse.dll": + case "neuraltreeevaluator.dll": + case "optimizationbuilderdotnet.dll": + case "parallelcommunicator.dll": + case "microsoft.ml.runtime.runtests.dll": + case "scopecompiler.dll": + case "tbb.dll": + case "internallearnscope.dll": + case "unmanagedlib.dll": + case "vcclient.dll": + case "libxgboost.dll": + case "zedgraph.dll": + case "__scopecodegen__.dll": + case "cosmosClientApi.dll": + return true; + } + + foreach (var s in _filePrefixesToAvoid) + { + if (name.StartsWith(s, StringComparison.OrdinalIgnoreCase)) + return true; + } + + return false; + } + + private static void LoadAssembliesInDir(IHostEnvironment env, string dir, bool filter) + { + if (!Directory.Exists(dir)) + return; + + using (var ch = env.Start("LoadAssembliesInDir")) + { + // Load all dlls in the given directory. + var paths = Directory.EnumerateFiles(dir, "*.dll"); + foreach (string path in paths) + { + if (filter && ShouldSkipPath(path)) + { + ch.Info($"Skipping assembly '{path}' because its name was filtered."); + continue; + } + + LoadAssembly(env, path); + } + } + } + + /// + /// Given an assembly path, load the assembly and register it with the ComponentCatalog. + /// + private static Assembly LoadAssembly(IHostEnvironment env, string path) + { + Assembly assembly = null; + try + { + assembly = Assembly.LoadFrom(path); + } + catch (Exception e) + { + using (var ch = env.Start("LoadAssembly")) + { + ch.Error("Could not load assembly {0}:\n{1}", path, e.ToString()); + } + return null; + } + + if (assembly != null) + { + TryRegisterAssembly(env.ComponentCatalog, assembly); + } + + return assembly; + } + + /// + /// Checks whether references the assembly containing LoadableClassAttributeBase, + /// and therefore can contain components. + /// + private static bool CanContainComponents(Assembly assembly) + { + var targetFullName = typeof(LoadableClassAttributeBase).Assembly.GetName().FullName; + + bool found = false; + foreach (var name in assembly.GetReferencedAssemblies()) + { + if (name.FullName == targetFullName) + { + found = true; + break; + } + } + + return found; + } + + private static void TryRegisterAssembly(ComponentCatalog catalog, Assembly assembly) + { + // Don't try to index dynamic generated assembly + if (assembly.IsDynamic) + return; + + if (!CanContainComponents(assembly)) + return; + + catalog.RegisterAssembly(assembly); + } + + private sealed class AssemblyRegistrar : IDisposable + { + private readonly IHostEnvironment _env; + + public AssemblyRegistrar(IHostEnvironment env, string path) + { + _env = env; + + RegisterCurrentLoadedAssemblies(_env); + + if (!string.IsNullOrEmpty(path)) + { + LoadAssembliesInDir(_env, path, true); + path = Path.Combine(path, "AutoLoad"); + LoadAssembliesInDir(_env, path, true); + } + + AppDomain.CurrentDomain.AssemblyLoad += CurrentDomainAssemblyLoad; + } + + public void Dispose() + { + AppDomain.CurrentDomain.AssemblyLoad -= CurrentDomainAssemblyLoad; + } + + private void CurrentDomainAssemblyLoad(object sender, AssemblyLoadEventArgs args) + { + TryRegisterAssembly(_env.ComponentCatalog, args.LoadedAssembly); + } + } + } +} diff --git a/src/Microsoft.ML.Api/ComponentCreation.cs b/src/Microsoft.ML.Api/ComponentCreation.cs index 1694a1cf4f..f34b25235c 100644 --- a/src/Microsoft.ML.Api/ComponentCreation.cs +++ b/src/Microsoft.ML.Api/ComponentCreation.cs @@ -439,7 +439,7 @@ private static TRes CreateCore(IHostEnvironment env, TArgs ar { env.CheckValue(args, nameof(args)); - var classes = ComponentCatalog.FindLoadableClasses(); + var classes = env.ComponentCatalog.FindLoadableClasses(); if (classes.Length == 0) throw env.Except("Couldn't find a {0} class that accepts {1} as arguments.", typeof(TRes).Name, typeof(TArgs).FullName); if (classes.Length > 1) diff --git a/src/Microsoft.ML.Api/SerializableLambdaTransform.cs b/src/Microsoft.ML.Api/SerializableLambdaTransform.cs index 7de6e522d8..38bd58e76d 100644 --- a/src/Microsoft.ML.Api/SerializableLambdaTransform.cs +++ b/src/Microsoft.ML.Api/SerializableLambdaTransform.cs @@ -29,7 +29,8 @@ public static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(SerializableLambdaTransform).Assembly.FullName); } public const string LoaderSignature = "UserLambdaMapTransform"; diff --git a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs index d5a204dd45..a25f8cecf6 100644 --- a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs +++ b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs @@ -414,7 +414,7 @@ public static string CombineSettings(string[] settings) } // REVIEW: Add a method for cloning arguments, instead of going to text and back. - public static string GetSettings(IExceptionContext ectx, object values, object defaults, SettingsFlags flags = SettingsFlags.Default) + public static string GetSettings(IHostEnvironment env, object values, object defaults, SettingsFlags flags = SettingsFlags.Default) { Type t1 = values.GetType(); Type t2 = defaults.GetType(); @@ -430,7 +430,7 @@ public static string GetSettings(IExceptionContext ectx, object values, object d return null; var info = GetArgumentInfo(t, defaults); - return GetSettingsCore(ectx, info, values, flags); + return GetSettingsCore(env, info, values, flags); } public static IEnumerable> GetSettingPairs(IHostEnvironment env, object values, object defaults, SettingsFlags flags = SettingsFlags.None) @@ -862,20 +862,20 @@ private bool Parse(ArgumentInfo info, string[] strs, object destination) return !hadError; } - private static string GetSettingsCore(IExceptionContext ectx, ArgumentInfo info, object values, SettingsFlags flags) + private static string GetSettingsCore(IHostEnvironment env, ArgumentInfo info, object values, SettingsFlags flags) { StringBuilder sb = new StringBuilder(); if (info.ArgDef != null) { var val = info.ArgDef.GetValue(values); - info.ArgDef.AppendSetting(ectx, sb, val, flags); + info.ArgDef.AppendSetting(env, sb, val, flags); } foreach (Argument arg in info.Args) { var val = arg.GetValue(values); - arg.AppendSetting(ectx, sb, val, flags); + arg.AppendSetting(env, sb, val, flags); } return sb.ToString(); @@ -886,7 +886,7 @@ private static string GetSettingsCore(IExceptionContext ectx, ArgumentInfo info, /// It deals with custom "unparse" functionality, as well as quoting. It also appends to a StringBuilder /// instead of returning a string. /// - private static void AppendCustomItem(IExceptionContext ectx, ArgumentInfo info, object values, SettingsFlags flags, StringBuilder sb) + private static void AppendCustomItem(IHostEnvironment env, ArgumentInfo info, object values, SettingsFlags flags, StringBuilder sb) { int ich = sb.Length; // We always call unparse, even when NoUnparse is specified, since Unparse can "cleanse", which @@ -902,13 +902,13 @@ private static void AppendCustomItem(IExceptionContext ectx, ArgumentInfo info, if (info.ArgDef != null) { var val = info.ArgDef.GetValue(values); - info.ArgDef.AppendSetting(ectx, sb, val, flags); + info.ArgDef.AppendSetting(env, sb, val, flags); } foreach (Argument arg in info.Args) { var val = arg.GetValue(values); - arg.AppendSetting(ectx, sb, val, flags); + arg.AppendSetting(env, sb, val, flags); } string str = sb.ToString(ich, sb.Length - ich); @@ -916,14 +916,14 @@ private static void AppendCustomItem(IExceptionContext ectx, ArgumentInfo info, CmdQuoter.QuoteValue(str, sb, force: true); } - private IEnumerable> GetSettingPairsCore(IExceptionContext ectx, ArgumentInfo info, object values, SettingsFlags flags) + private IEnumerable> GetSettingPairsCore(IHostEnvironment env, ArgumentInfo info, object values, SettingsFlags flags) { StringBuilder buffer = new StringBuilder(); foreach (Argument arg in info.Args) { string key = arg.GetKey(flags); object value = arg.GetValue(values); - foreach (string val in arg.GetSettingStrings(ectx, value, buffer)) + foreach (string val in arg.GetSettingStrings(env, value, buffer)) yield return new KeyValuePair(key, val); } } @@ -943,13 +943,13 @@ public ArgumentHelpStrings(string syntax, string help) /// /// A user friendly usage string describing the command line argument syntax. /// - private string GetUsageString(IExceptionContext ectx, ArgumentInfo info, bool showRsp = true, int? columns = null) + private string GetUsageString(IHostEnvironment env, ArgumentInfo info, bool showRsp = true, int? columns = null) { int screenWidth = columns ?? Console.BufferWidth; if (screenWidth <= 0) screenWidth = 80; - ArgumentHelpStrings[] strings = GetAllHelpStrings(ectx, info, showRsp); + ArgumentHelpStrings[] strings = GetAllHelpStrings(env, info, showRsp); int maxParamLen = 0; foreach (ArgumentHelpStrings helpString in strings) @@ -1039,17 +1039,17 @@ private static void AddNewLine(string newLine, StringBuilder builder, ref int cu currentColumn = 0; } - private static ArgumentHelpStrings[] GetAllHelpStrings(IExceptionContext ectx, ArgumentInfo info, bool showRsp) + private ArgumentHelpStrings[] GetAllHelpStrings(IHostEnvironment env, ArgumentInfo info, bool showRsp) { List strings = new List(); if (info.ArgDef != null) - strings.Add(GetHelpStrings(ectx, info.ArgDef)); + strings.Add(GetHelpStrings(env, info.ArgDef)); foreach (Argument arg in info.Args) { if (!arg.IsHidden) - strings.Add(GetHelpStrings(ectx, arg)); + strings.Add(GetHelpStrings(env, arg)); } if (showRsp) @@ -1058,9 +1058,9 @@ private static ArgumentHelpStrings[] GetAllHelpStrings(IExceptionContext ectx, A return strings.ToArray(); } - private static ArgumentHelpStrings GetHelpStrings(IExceptionContext ectx, Argument arg) + private ArgumentHelpStrings GetHelpStrings(IHostEnvironment env, Argument arg) { - return new ArgumentHelpStrings(arg.GetSyntaxHelp(), arg.GetFullHelpText(ectx)); + return new ArgumentHelpStrings(arg.GetSyntaxHelp(), arg.GetFullHelpText(env, this)); } private bool LexFileArguments(string fileName, out string[] arguments) @@ -1994,7 +1994,7 @@ private bool ParseValue(CmdParser owner, string data, out object value) return false; } - private void AppendHelpValue(IExceptionContext ectx, StringBuilder builder, object value) + private void AppendHelpValue(IHostEnvironment env, CmdParser owner, StringBuilder builder, object value) { if (value == null) builder.Append("{}"); @@ -2006,19 +2006,19 @@ private void AppendHelpValue(IExceptionContext ectx, StringBuilder builder, obje foreach (object o in (System.Array)value) { builder.Append(pre); - AppendHelpValue(ectx, builder, o); + AppendHelpValue(env, owner, builder, o); pre = ", "; } } else if (value is IComponentFactory) { string name; - var catalog = ModuleCatalog.CreateInstance(ectx); + var catalog = owner._catalog.Value; var type = value.GetType(); bool success = catalog.TryGetComponentShortName(type, out name); Contracts.Assert(success); - var settings = GetSettings(ectx, value, Activator.CreateInstance(type)); + var settings = GetSettings(env, value, Activator.CreateInstance(type)); builder.Append(name); if (!string.IsNullOrWhiteSpace(settings)) { @@ -2035,7 +2035,7 @@ private void AppendHelpValue(IExceptionContext ectx, StringBuilder builder, obje } // If value differs from the default, appends the setting to sb. - public void AppendSetting(IExceptionContext ectx, StringBuilder sb, object value, SettingsFlags flags) + public void AppendSetting(IHostEnvironment env, StringBuilder sb, object value, SettingsFlags flags) { object def = DefaultValue; if (!IsCollection) @@ -2043,13 +2043,13 @@ public void AppendSetting(IExceptionContext ectx, StringBuilder sb, object value if (value == null) { if (def != null || IsRequired) - AppendSettingCore(ectx, sb, "", flags); + AppendSettingCore(env, sb, "", flags); } else if (def == null || !value.Equals(def)) { var buffer = new StringBuilder(); - if (!(value is IComponentFactory) || (GetString(ectx, value, buffer) != GetString(ectx, def, buffer))) - AppendSettingCore(ectx, sb, value, flags); + if (!(value is IComponentFactory) || (GetString(env, value, buffer) != GetString(env, def, buffer))) + AppendSettingCore(env, sb, value, flags); } return; } @@ -2076,10 +2076,10 @@ public void AppendSetting(IExceptionContext ectx, StringBuilder sb, object value } foreach (object x in vals) - AppendSettingCore(ectx, sb, x, flags); + AppendSettingCore(env, sb, x, flags); } - private void AppendSettingCore(IExceptionContext ectx, StringBuilder sb, object value, SettingsFlags flags) + private void AppendSettingCore(IHostEnvironment env, StringBuilder sb, object value, SettingsFlags flags) { if (sb.Length > 0) sb.Append(" "); @@ -2100,11 +2100,11 @@ private void AppendSettingCore(IExceptionContext ectx, StringBuilder sb, object else if (value is bool) sb.Append((bool)value ? "+" : "-"); else if (IsCustomItemType) - AppendCustomItem(ectx, _infoCustom, value, flags, sb); + AppendCustomItem(env, _infoCustom, value, flags, sb); else if (IsComponentFactory) { var buffer = new StringBuilder(); - sb.Append(GetString(ectx, value, buffer)); + sb.Append(GetString(env, value, buffer)); } else sb.Append(value.ToString()); @@ -2129,7 +2129,7 @@ private void ExtractTag(object value, out string tag, out object newValue) // If value differs from the default, return the string representation of 'value', // or an IEnumerable of string representations if 'value' is an array. - public IEnumerable GetSettingStrings(IExceptionContext ectx, object value, StringBuilder buffer) + public IEnumerable GetSettingStrings(IHostEnvironment env, object value, StringBuilder buffer) { object def = DefaultValue; @@ -2138,12 +2138,12 @@ public IEnumerable GetSettingStrings(IExceptionContext ectx, object valu if (value == null) { if (def != null || IsRequired) - yield return GetString(ectx, value, buffer); + yield return GetString(env, value, buffer); } else if (def == null || !value.Equals(def)) { - if (!(value is IComponentFactory) || (GetString(ectx, value, buffer) != GetString(ectx, def, buffer))) - yield return GetString(ectx, value, buffer); + if (!(value is IComponentFactory) || (GetString(env, value, buffer) != GetString(env, def, buffer))) + yield return GetString(env, value, buffer); } yield break; } @@ -2171,10 +2171,10 @@ public IEnumerable GetSettingStrings(IExceptionContext ectx, object valu } foreach (object x in vals) - yield return GetString(ectx, x, buffer); + yield return GetString(env, x, buffer); } - private string GetString(IExceptionContext ectx, object value, StringBuilder buffer) + private string GetString(IHostEnvironment env, object value, StringBuilder buffer) { if (value == null) return "{}"; @@ -2192,12 +2192,13 @@ private string GetString(IExceptionContext ectx, object value, StringBuilder buf if (value is IComponentFactory) { string name; - var catalog = ModuleCatalog.CreateInstance(ectx); + // TODO: ModuleCatalog should be on env + var catalog = ModuleCatalog.CreateInstance(env); var type = value.GetType(); bool isModuleComponent = catalog.TryGetComponentShortName(type, out name); if (isModuleComponent) { - var settings = GetSettings(ectx, value, Activator.CreateInstance(type)); + var settings = GetSettings(env, value, Activator.CreateInstance(type)); buffer.Clear(); buffer.Append(name); if (!string.IsNullOrWhiteSpace(settings)) @@ -2208,20 +2209,12 @@ private string GetString(IExceptionContext ectx, object value, StringBuilder buf } return buffer.ToString(); } - else if (value is ICommandLineComponentFactory) - { - return value.ToString(); - } - else - { - throw ectx.Except($"IComponentFactory instances either need to be EntryPointComponents or implement {nameof(ICommandLineComponentFactory)}."); - } } return value.ToString(); } - public string GetFullHelpText(IExceptionContext ectx) + public string GetFullHelpText(IHostEnvironment env, CmdParser owner) { if (IsHidden) return null; @@ -2248,7 +2241,7 @@ public string GetFullHelpText(IExceptionContext ectx) if (builder.Length > 0) builder.Append(" "); builder.Append("Default value:'"); - AppendHelpValue(ectx, builder, DefaultValue); + AppendHelpValue(env, owner, builder, DefaultValue); builder.Append('\''); } if (Utils.Size(ShortNames) != 0) diff --git a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs index 78ec31205f..c09c7c8e70 100644 --- a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs +++ b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs @@ -2,15 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.Runtime.Internal.Utilities; using System; using System.Collections.Concurrent; using System.Collections.Generic; -using System.IO; -using System.IO.Compression; using System.Linq; using System.Reflection; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.CommandLine; // REVIEW: Determine ideal namespace. namespace Microsoft.ML.Runtime @@ -22,8 +20,17 @@ namespace Microsoft.ML.Runtime /// types for component instantiation. Each component may also specify an "arguments object" that should /// be provided at instantiation time. /// - public static class ComponentCatalog + public sealed class ComponentCatalog { + internal ComponentCatalog() + { + _lock = new object(); + _cachedAssemblies = new HashSet(); + _classesByKey = new ConcurrentDictionary(); + _classes = new ConcurrentQueue(); + _signatures = new ConcurrentDictionary(); + } + /// /// Provides information on an instantiatable component, aka, loadable class. /// @@ -248,201 +255,18 @@ public object CreateArguments() } } - /// - /// Debug reporting level. - /// - public static int DebugLevel = 1; - - // Do not initialize this one - the initial null value is used as a "flag" to prime things. - private static ConcurrentQueue _assemblyQueue; - - // The assemblies that are loaded by Reflection.LoadAssembly or Assembly.Load* after we started tracking - // the load events. We will provide assembly resolving for these assemblies. This is created simultaneously - // with s_assemblyQueue. - private static ConcurrentDictionary _loadedAssemblies; - - // This lock protects s_cachedAssemblies and s_cachedPaths only. The collection of ClassInfos is concurrent + // This lock protects _cachedAssemblies only. The collection of ClassInfos is concurrent // so needs no protection. - private static object _lock = new object(); - private static HashSet _cachedAssemblies = new HashSet(); - private static HashSet _cachedPaths = new HashSet(); + private readonly object _lock; + private readonly HashSet _cachedAssemblies; // Map from key/name to loadable class. Note that the same ClassInfo may appear // multiple times. For the set of unique infos, use s_classes. - private static ConcurrentDictionary _classesByKey = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _classesByKey; // The unique ClassInfos and Signatures. - private static ConcurrentQueue _classes = new ConcurrentQueue(); - private static ConcurrentDictionary _signatures = new ConcurrentDictionary(); - - public static string[] FilePrefixesToAvoid = new string[] { - "api-ms-win", - "clr", - "coreclr", - "dbgshim", - "ext-ms-win", - "microsoft.bond.", - "microsoft.cosmos.", - "microsoft.csharp", - "microsoft.data.", - "microsoft.hpc.", - "microsoft.live.", - "microsoft.platformbuilder.", - "microsoft.visualbasic", - "microsoft.visualstudio.", - "microsoft.win32", - "microsoft.windowsapicodepack.", - "microsoft.windowsazure.", - "mscor", - "msvc", - "petzold.", - "roslyn.", - "sho", - "sni", - "sqm", - "system.", - "zlib", - }; - - private static bool ShouldSkipPath(string path) - { - string name = Path.GetFileName(path).ToLowerInvariant(); - switch (name) - { - case "cqo.dll": - case "fasttreenative.dll": - case "libiomp5md.dll": - case "libvw.dll": - case "matrixinterf.dll": - case "microsoft.ml.neuralnetworks.gpucuda.dll": - case "mklimports.dll": - case "microsoft.research.controls.decisiontrees.dll": - case "microsoft.ml.neuralnetworks.sse.dll": - case "neuraltreeevaluator.dll": - case "optimizationbuilderdotnet.dll": - case "parallelcommunicator.dll": - case "microsoft.ml.runtime.runtests.dll": - case "scopecompiler.dll": - case "tbb.dll": - case "internallearnscope.dll": - case "unmanagedlib.dll": - case "vcclient.dll": - case "libxgboost.dll": - case "zedgraph.dll": - case "__scopecodegen__.dll": - case "cosmosClientApi.dll": - return true; - } - - foreach (var s in FilePrefixesToAvoid) - { - if (name.StartsWith(s)) - return true; - } - - return false; - } - - /// - /// This loads assemblies that are in our "root" directory (where this assembly is) and caches - /// information for the loadable classes in loaded assemblies. - /// - private static void CacheLoadedAssemblies() - { - // The target assembly is the one containing LoadableClassAttributeBase. If an assembly doesn't reference - // the target, then we don't want to scan its assembly attributes (there's no point in doing so). - var target = typeof(LoadableClassAttributeBase).Assembly; - - lock (_lock) - { - if (_assemblyQueue == null) - { - // Create the loaded assembly queue and dictionary, set up the AssemblyLoad / AssemblyResolve - // event handlers and populate the queue / dictionary with all assemblies that are currently loaded. - Contracts.Assert(_assemblyQueue == null); - Contracts.Assert(_loadedAssemblies == null); - - _assemblyQueue = new ConcurrentQueue(); - _loadedAssemblies = new ConcurrentDictionary(); - - AppDomain.CurrentDomain.AssemblyLoad += CurrentDomainAssemblyLoad; - AppDomain.CurrentDomain.AssemblyResolve += CurrentDomainAssemblyResolve; - - foreach (Assembly a in AppDomain.CurrentDomain.GetAssemblies()) - { - // Ignore dynamic assemblies. - if (a.IsDynamic) - continue; - - _assemblyQueue.Enqueue(a); - if (!_loadedAssemblies.TryAdd(a.FullName, a)) - { - // Duplicate loading. - Console.Error.WriteLine("Duplicate loaded assembly '{0}'", a.FullName); - } - } - - // Load all assemblies in our directory. - var moduleName = typeof(ComponentCatalog).Module.FullyQualifiedName; - - // If were are loaded in the context of SQL CLR then the FullyQualifiedName and Name properties are set to - // string "" and we skip scanning current directory. - if (moduleName != "") - { - string dir = Path.GetDirectoryName(moduleName); - LoadAssembliesInDir(dir, true); - dir = Path.Combine(dir, "AutoLoad"); - LoadAssembliesInDir(dir, true); - } - } - - Contracts.AssertValue(_assemblyQueue); - Contracts.AssertValue(_loadedAssemblies); - - Assembly assembly; - while (_assemblyQueue.TryDequeue(out assembly)) - { - if (!_cachedAssemblies.Add(assembly.FullName)) - continue; - - if (assembly != target) - { - bool found = false; - var targetName = target.GetName(); - foreach (var name in assembly.GetReferencedAssemblies()) - { - if (name.Name == targetName.Name) - { - found = true; - break; - } - } - if (!found) - continue; - } - - int added = 0; - foreach (LoadableClassAttributeBase attr in assembly.GetCustomAttributes(typeof(LoadableClassAttributeBase))) - { - MethodInfo getter = null; - ConstructorInfo ctor = null; - MethodInfo create = null; - bool requireEnvironment = false; - if (attr.InstanceType != typeof(void) && !TryGetIniters(attr.InstanceType, attr.LoaderType, attr.CtorTypes, out getter, out ctor, out create, out requireEnvironment)) - { - Console.Error.WriteLine( - "CacheClassesFromAssembly: can't instantiate loadable class {0} with name {1}", - attr.InstanceType.Name, attr.LoadNames[0]); - Contracts.Assert(getter == null && ctor == null && create == null); - } - var info = new LoadableClassInfo(attr, getter, ctor, create, requireEnvironment); - - AddClass(info, attr.LoadNames); - added++; - } - } - } - } + private readonly ConcurrentQueue _classes; + private readonly ConcurrentDictionary _signatures; private static bool TryGetIniters(Type instType, Type loaderType, Type[] parmTypes, out MethodInfo getter, out ConstructorInfo ctor, out MethodInfo create, out bool requireEnvironment) @@ -472,7 +296,7 @@ private static bool TryGetIniters(Type instType, Type loaderType, Type[] parmTyp return false; } - private static void AddClass(LoadableClassInfo info, string[] loadNames) + private void AddClass(LoadableClassInfo info, string[] loadNames) { _classes.Enqueue(info); foreach (var sigType in info.SignatureTypes) @@ -528,157 +352,53 @@ private static MethodInfo FindCreateMethod(Type instType, Type loaderType, Type[ return meth; } - private static void LoadAssembliesInDir(string dir, bool filter) - { - if (!Directory.Exists(dir)) - return; - - // Load all dlls in the given directory. - var paths = Directory.EnumerateFiles(dir, "*.dll"); - foreach (string path in paths) - { - if (filter && ShouldSkipPath(path)) - continue; - // Loading the assembly is enough because of our event handler. - LoadAssembly(path); - } - } - - private static void CurrentDomainAssemblyLoad(object sender, AssemblyLoadEventArgs args) - { - // Don't try to index dynamic generated assembly - if (args.LoadedAssembly.IsDynamic) - return; - _assemblyQueue.Enqueue(args.LoadedAssembly); - if (!_loadedAssemblies.TryAdd(args.LoadedAssembly.FullName, args.LoadedAssembly)) - { - // Duplicate loading. - Console.Error.WriteLine("Duplicate loading of the assembly '{0}'", args.LoadedAssembly.FullName); - } - } - - private static Assembly CurrentDomainAssemblyResolve(object sender, ResolveEventArgs args) - { - // REVIEW: currently, the resolving happens on exact matches of the full name. - // This has proved to work with the C# transform. We might need to change the resolving logic when the need arises. - Assembly found; - if (_loadedAssemblies.TryGetValue(args.Name, out found)) - return found; - return null; - } - /// - /// Given an assembly path, load the assembly. + /// Registers all the components in the specified assembly by looking for loadable classes + /// and adding them to the catalog. /// - public static Assembly LoadAssembly(string path) + /// + /// The assembly to register. + /// + /// + /// true to throw an exception if there are errors with registering the components; + /// false to skip any errors. + /// + public void RegisterAssembly(Assembly assembly, bool throwOnError = true) { - try - { - return LoadFrom(path); - } - catch (Exception e) - { - if (DebugLevel > 2) - Console.Error.WriteLine("Could not load assembly {0}:\n{1}", path, e.ToString()); - return null; - } - } - - private static Assembly LoadFrom(string path) - { - Contracts.CheckNonEmpty(path, nameof(path)); - return Assembly.LoadFrom(path); - } - - /// - /// Make sure the given assemblies are loaded and that their loadable classes have been catalogued. - /// - public static void CacheClassesExtra(string[] assemblies) - { - if (Utils.Size(assemblies) > 0) + lock (_lock) { - lock (_lock) + if (_cachedAssemblies.Add(assembly.FullName)) { - foreach (string path in assemblies) + foreach (LoadableClassAttributeBase attr in assembly.GetCustomAttributes(typeof(LoadableClassAttributeBase))) { - if (!_cachedPaths.Add(path)) - continue; - - Exception ex = null; - try - { - // REVIEW: Will LoadFrom ever return null? - var assem = LoadFrom(path); - if (assem != null) - continue; - } - catch (Exception e) - { - ex = e; - } - - // If it is a zip file, load it that way. - ZipArchive zip; - try - { - zip = ZipFile.OpenRead(path); - } - catch (Exception e) - { - // Couldn't load as an assembly and not a zip, so warn the user. - ex = ex ?? e; - Console.Error.WriteLine("Warning: Could not load '{0}': {1}", path, ex.Message); - continue; - } - - string dir; - try - { - dir = CreateTempDirectory(); - } - catch (Exception e) - { - throw Contracts.ExceptIO(e, "Creating temp directory for extra assembly zip extraction failed: '{0}'", path); - } - - try - { - zip.ExtractToDirectory(dir); - } - catch (Exception e) + MethodInfo getter = null; + ConstructorInfo ctor = null; + MethodInfo create = null; + bool requireEnvironment = false; + if (attr.InstanceType != typeof(void) && !TryGetIniters(attr.InstanceType, attr.LoaderType, attr.CtorTypes, out getter, out ctor, out create, out requireEnvironment)) { - throw Contracts.ExceptIO(e, "Extracting extra assembly zip failed: '{0}'", path); + if (throwOnError) + { + throw new InvalidOperationException(string.Format( + "Can't instantiate loadable class {0} with name {1}", + attr.InstanceType.Name, attr.LoadNames[0])); + } + Contracts.Assert(getter == null && ctor == null && create == null); } + var info = new LoadableClassInfo(attr, getter, ctor, create, requireEnvironment); - LoadAssembliesInDir(dir, false); + AddClass(info, attr.LoadNames); } } } - - CacheLoadedAssemblies(); - } - - private static string CreateTempDirectory() - { - string dir = GetTempPath(); - Directory.CreateDirectory(dir); - return dir; - } - - private static string GetTempPath() - { - Guid guid = Guid.NewGuid(); - return Path.GetFullPath(Path.Combine(Path.GetTempPath(), "TLC_" + guid.ToString())); } /// /// Return an array containing information for all instantiatable components. /// If provided, the given set of assemblies is loaded first. /// - public static LoadableClassInfo[] GetAllClasses(string[] assemblies = null) + public LoadableClassInfo[] GetAllClasses() { - CacheClassesExtra(assemblies); - return _classes.ToArray(); } @@ -686,13 +406,11 @@ public static LoadableClassInfo[] GetAllClasses(string[] assemblies = null) /// Return an array containing information for instantiatable components with the given /// signature and base type. If provided, the given set of assemblies is loaded first. /// - public static LoadableClassInfo[] GetAllDerivedClasses(Type typeBase, Type typeSig, string[] assemblies = null) + public LoadableClassInfo[] GetAllDerivedClasses(Type typeBase, Type typeSig) { Contracts.CheckValue(typeBase, nameof(typeBase)); Contracts.CheckValueOrNull(typeSig); - CacheClassesExtra(assemblies); - // Apply the default. if (typeSig == null) typeSig = typeof(SignatureDefault); @@ -706,10 +424,8 @@ public static LoadableClassInfo[] GetAllDerivedClasses(Type typeBase, Type typeS /// Return an array containing all the known signature types. If provided, the given set of assemblies /// is loaded first. /// - public static Type[] GetAllSignatureTypes(string[] assemblies = null) + public Type[] GetAllSignatureTypes() { - CacheClassesExtra(assemblies); - return _signatures.Select(kvp => kvp.Key).ToArray(); } @@ -726,58 +442,47 @@ public static string SignatureToString(Type sig) return kind; } - private static LoadableClassInfo FindClassCore(LoadableClassInfo.Key key) + private LoadableClassInfo FindClassCore(LoadableClassInfo.Key key) { LoadableClassInfo info; - if (_classesByKey.TryGetValue(key, out info)) - return info; - - CacheLoadedAssemblies(); - if (_classesByKey.TryGetValue(key, out info)) return info; return null; } - public static LoadableClassInfo[] FindLoadableClasses(string name) + public LoadableClassInfo[] FindLoadableClasses(string name) { name = name.ToLowerInvariant().Trim(); - CacheLoadedAssemblies(); - var res = _classes .Where(ci => ci.LoadNames.Select(n => n.ToLowerInvariant().Trim()).Contains(name)) .ToArray(); return res; } - public static LoadableClassInfo[] FindLoadableClasses() + public LoadableClassInfo[] FindLoadableClasses() { - CacheLoadedAssemblies(); - return _classes .Where(ci => ci.SignatureTypes.Contains(typeof(TSig))) .ToArray(); } - public static LoadableClassInfo[] FindLoadableClasses() + public LoadableClassInfo[] FindLoadableClasses() { // REVIEW: this and above methods perform a linear search over all the loadable classes. // On 6/15/2015, TLC release build contained 431 of them, so adding extra lookups looks unnecessary at this time. - CacheLoadedAssemblies(); - return _classes .Where(ci => ci.ArgType == typeof(TArgs) && ci.SignatureTypes.Contains(typeof(TSig))) .ToArray(); } - public static LoadableClassInfo GetLoadableClassInfo(string loadName) + public LoadableClassInfo GetLoadableClassInfo(string loadName) { return GetLoadableClassInfo(loadName, typeof(TSig)); } - public static LoadableClassInfo GetLoadableClassInfo(string loadName, Type signatureType) + public LoadableClassInfo GetLoadableClassInfo(string loadName, Type signatureType) { Contracts.CheckParam(signatureType.BaseType == typeof(MulticastDelegate), nameof(signatureType), "signatureType must be a delegate type"); Contracts.CheckValueOrNull(loadName); @@ -815,7 +520,7 @@ private static bool TryCreateInstance(IHostEnvironment env, Type signature env.CheckValueOrNull(name); string nameLower = (name ?? "").ToLowerInvariant().Trim(); - LoadableClassInfo info = FindClassCore(new LoadableClassInfo.Key(nameLower, signatureType)); + LoadableClassInfo info = env.ComponentCatalog.FindClassCore(new LoadableClassInfo.Key(nameLower, signatureType)); if (info == null) { result = null; diff --git a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs index b463e52a8e..bd08a75c9a 100644 --- a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs +++ b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs @@ -46,6 +46,11 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider /// bool IsCancelled { get; } + /// + /// The catalog of loadable components () that are available in this host. + /// + ComponentCatalog ComponentCatalog { get; } + /// /// Return a file handle for an input "file". /// diff --git a/src/Microsoft.ML.Core/Data/ServerChannel.cs b/src/Microsoft.ML.Core/Data/ServerChannel.cs index fd0e1f2603..4b6db41541 100644 --- a/src/Microsoft.ML.Core/Data/ServerChannel.cs +++ b/src/Microsoft.ML.Core/Data/ServerChannel.cs @@ -171,16 +171,16 @@ public interface IServer : IDisposable /// for example, if a user opted to remove all implementations of and /// the associated for security reasons. /// - public static IServerFactory CreateDefaultServerFactoryOrNull(IExceptionContext ectx) + public static IServerFactory CreateDefaultServerFactoryOrNull(IHostEnvironment env) { - Contracts.CheckValue(ectx, nameof(ectx)); + Contracts.CheckValue(env, nameof(env)); // REVIEW: There should be a better way. There currently isn't, // but there should be. This is pretty horrifying, but it is preferable to // the alternative of having core components depend on an actual server // implementation, since we want those to be removable because of security // concerns in certain environments (since not everyone will be wild about // web servers popping up everywhere). - var cat = ModuleCatalog.CreateInstance(ectx); + var cat = ModuleCatalog.CreateInstance(env); ModuleCatalog.ComponentInfo component; if (!cat.TryFindComponent(typeof(IServerFactory), "mini", out component)) return null; diff --git a/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs b/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs index 60511bfd39..a34242b6f4 100644 --- a/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs +++ b/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs @@ -152,7 +152,6 @@ internal ComponentInfo(IExceptionContext ectx, Type interfaceType, string kind, } } - private static volatile ModuleCatalog _instance; private readonly EntryPointInfo[] _entryPoints; private readonly Dictionary _entryPointMap; @@ -167,15 +166,15 @@ public IEnumerable AllEntryPoints() return _entryPoints.AsEnumerable(); } - private ModuleCatalog(IExceptionContext ectx) + private ModuleCatalog(IHostEnvironment env) { - Contracts.AssertValue(ectx); + Contracts.AssertValue(env); _entryPointMap = new Dictionary(); _componentMap = new Dictionary(); _components = new List(); - var moduleClasses = ComponentCatalog.FindLoadableClasses(); + var moduleClasses = env.ComponentCatalog.FindLoadableClasses(); var entryPoints = new List(); foreach (var lc in moduleClasses) @@ -189,7 +188,7 @@ private ModuleCatalog(IExceptionContext ectx) if (attr == null) continue; - var info = new EntryPointInfo(ectx, methodInfo, attr, + var info = new EntryPointInfo(env, methodInfo, attr, methodInfo.GetCustomAttributes(typeof(ObsoleteAttribute), false).FirstOrDefault() as ObsoleteAttribute); entryPoints.Add(info); @@ -205,9 +204,9 @@ private ModuleCatalog(IExceptionContext ectx) // Scan for components. // First scan ourself, and then all nested types, for component info. - ScanForComponents(ectx, type); + ScanForComponents(env, type); foreach (var nestedType in type.GetTypeInfo().GetNestedTypes()) - ScanForComponents(ectx, nestedType); + ScanForComponents(env, nestedType); } _entryPoints = entryPoints.ToArray(); } @@ -274,17 +273,13 @@ private static bool IsValidName(string name) } /// - /// Create a module catalog (or reuse the one created before). + /// Create a module catalog. /// - /// The exception context to use to report errors while scanning the assemblies. - public static ModuleCatalog CreateInstance(IExceptionContext ectx) + /// The host environment and exception context to use to report errors while scanning the assemblies. + public static ModuleCatalog CreateInstance(IHostEnvironment env) { - Contracts.CheckValueOrNull(ectx); -#pragma warning disable 420 // volatile with Interlocked.CompareExchange. - if (_instance == null) - Interlocked.CompareExchange(ref _instance, new ModuleCatalog(ectx), null); -#pragma warning restore 420 - return _instance; + Contracts.CheckValue(env, nameof(env)); + return new ModuleCatalog(env); } public bool TryFindEntryPoint(string name, out EntryPointInfo entryPoint) diff --git a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs index d4ff5ccd96..b195de6108 100644 --- a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs +++ b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs @@ -6,7 +6,6 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; -using System.Threading; namespace Microsoft.ML.Runtime.Data { @@ -382,6 +381,8 @@ public void RemoveListener(Action listenerFunc) public bool IsCancelled { get; protected set; } + public ComponentCatalog ComponentCatalog { get; } + public override int Depth => 0; protected bool IsDisposed => _tempFiles == null; @@ -402,6 +403,7 @@ protected HostEnvironmentBase(IRandom rand, bool verbose, int conc, _tempLock = new object(); _tempFiles = new List(); Root = this as TEnv; + ComponentCatalog = new ComponentCatalog(); } /// @@ -422,6 +424,7 @@ protected HostEnvironmentBase(HostEnvironmentBase source, IRandom rand, bo Root = source.Root; ListenerDict = source.ListenerDict; ProgressTracker = source.ProgressTracker; + ComponentCatalog = source.ComponentCatalog; } /// diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index f6670394ee..837c3beb3e 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -119,7 +119,7 @@ public override void Run() using (var ch = Host.Start(LoadName)) using (var server = InitServer(ch)) { - var settings = CmdParser.GetSettings(ch, Args, new Arguments()); + var settings = CmdParser.GetSettings(Host, Args, new Arguments()); string cmd = string.Format("maml.exe {0} {1}", LoadName, settings); ch.Info(cmd); @@ -557,7 +557,7 @@ private FoldResult RunFold(int fold) var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, scorerFactorySettings: _scorer as ICommandLineComponentFactory); ch.AssertValue(bindable); var mapper = bindable.Bind(host, testData.Schema); - var scorerComp = _scorer ?? ScoreUtils.GetScorerComponent(mapper); + var scorerComp = _scorer ?? ScoreUtils.GetScorerComponent(host, mapper); IDataScorerTransform scorePipe = scorerComp.CreateComponent(host, testData.Data, mapper, trainData.Schema); // Save per-fold model. diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs index 301603b14a..014fdf87f8 100644 --- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs +++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs @@ -120,7 +120,7 @@ private void RunCore(IChannel ch) var mapper = bindable.Bind(Host, schema); if (scorer == null) - scorer = ScoreUtils.GetScorerComponent(mapper); + scorer = ScoreUtils.GetScorerComponent(Host, mapper); loader = CompositeDataLoader.ApplyTransform(Host, loader, "Scorer", scorer.ToString(), (env, view) => scorer.CreateComponent(env, view, mapper, trainSchema)); @@ -284,7 +284,7 @@ private static TScorerFactory GetScorerComponentAndMapper( mapper = bindable.Bind(env, schema); if (scorerFactory != null) return scorerFactory; - return GetScorerComponent(mapper); + return GetScorerComponent(env, mapper); } /// @@ -292,12 +292,15 @@ private static TScorerFactory GetScorerComponentAndMapper( /// metadata on the first column of the mapper. If that text is found and maps to a scorer loadable class, /// that component is used. Otherwise, the GenericScorer is used. /// + /// The host environment.. /// The schema bound mapper to get the default scorer.. /// An optional suffix to append to the default column names. public static TScorerFactory GetScorerComponent( + IHostEnvironment environment, ISchemaBoundMapper mapper, string suffix = null) { + Contracts.CheckValue(environment, nameof(environment)); Contracts.AssertValue(mapper); ComponentCatalog.LoadableClassInfo info = null; @@ -307,7 +310,7 @@ public static TScorerFactory GetScorerComponent( !scoreKind.IsEmpty) { var loadName = scoreKind.ToString(); - info = ComponentCatalog.GetLoadableClassInfo(loadName); + info = environment.ComponentCatalog.GetLoadableClassInfo(loadName); if (info == null || !typeof(IDataScorerTransform).IsAssignableFrom(info.Type)) info = null; } diff --git a/src/Microsoft.ML.Data/Commands/TestCommand.cs b/src/Microsoft.ML.Data/Commands/TestCommand.cs index ca32da0ddd..1630d9af57 100644 --- a/src/Microsoft.ML.Data/Commands/TestCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TestCommand.cs @@ -67,7 +67,7 @@ public override void Run() using (var ch = Host.Start(command)) using (var server = InitServer(ch)) { - var settings = CmdParser.GetSettings(ch, Args, new Arguments()); + var settings = CmdParser.GetSettings(Host, Args, new Arguments()); ch.Info("maml.exe {0} {1}", command, settings); SendTelemetry(Host); diff --git a/src/Microsoft.ML.Data/Commands/TrainCommand.cs b/src/Microsoft.ML.Data/Commands/TrainCommand.cs index acc03b743f..1a90881464 100644 --- a/src/Microsoft.ML.Data/Commands/TrainCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainCommand.cs @@ -108,7 +108,7 @@ public override void Run() using (var ch = Host.Start(command)) using (var server = InitServer(ch)) { - var settings = CmdParser.GetSettings(ch, Args, new Arguments()); + var settings = CmdParser.GetSettings(Host, Args, new Arguments()); string cmd = string.Format("maml.exe {0} {1}", command, settings); ch.Info(cmd); diff --git a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs index 270f7c3b03..63c8691747 100644 --- a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs @@ -94,7 +94,7 @@ public override void Run() using (var ch = Host.Start(LoadName)) using (var server = InitServer(ch)) { - var settings = CmdParser.GetSettings(ch, Args, new Arguments()); + var settings = CmdParser.GetSettings(Host, Args, new Arguments()); string cmd = string.Format("maml.exe {0} {1}", LoadName, settings); ch.Info(cmd); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs index 582212738a..9991956fcb 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs @@ -778,7 +778,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010003, // Number of blocks to put in the shuffle pool verReadableCur: 0x00010003, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(BinaryLoader).Assembly.FullName); } private BinaryLoader(Arguments args, IHost host, Stream stream, bool leaveOpen) diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs index cb18bdce1d..d71f6b9ac5 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs @@ -71,7 +71,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Added transform tags and args strings verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(CompositeDataLoader).Assembly.FullName); } // The composition of loader plus transforms in order. diff --git a/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs index 27ac28c717..6c52cefa56 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs @@ -60,7 +60,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoadName); + loaderSignature: LoadName, + loaderAssemblyName: typeof(PartitionedFileLoader).Assembly.FullName); } public class Arguments diff --git a/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs b/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs index 70d8f898ab..c10fe09116 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs @@ -91,7 +91,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoadName); + loaderSignature: LoadName, + loaderAssemblyName: typeof(SimplePartitionedPathParser).Assembly.FullName); } private IHost _host; @@ -214,7 +215,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoadName); + loaderSignature: LoadName, + loaderAssemblyName: typeof(ParquetPartitionedPathParser).Assembly.FullName); } public ParquetPartitionedPathParser() diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index 3663c93cc4..195befe518 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -960,7 +960,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x0001000B, // Header now retained if used and present verReadableCur: 0x0001000A, verWeCanReadBack: 0x00010009, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(TextLoader).Assembly.FullName); } /// @@ -1180,13 +1181,13 @@ private static bool TryParseSchema(IHost host, IMultiStreamSource files, goto LDone; // Make sure the loader binds to us. - var info = ComponentCatalog.GetLoadableClassInfo(loader.Name); + var info = host.ComponentCatalog.GetLoadableClassInfo(loader.Name); if (info.Type != typeof(IDataLoader) || info.ArgType != typeof(Arguments)) goto LDone; var argsNew = new Arguments(); // Copy the non-core arguments to the new args (we already know that all the core arguments are default). - var parsed = CmdParser.ParseArguments(host, CmdParser.GetSettings(ch, args, new Arguments()), argsNew); + var parsed = CmdParser.ParseArguments(host, CmdParser.GetSettings(host, args, new Arguments()), argsNew); ch.Assert(parsed); // Copy the core arguments to the new args. if (!CmdParser.ParseArguments(host, loader.GetSettingsString(), argsNew, typeof(ArgumentsCore), msg => ch.Error(msg))) diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs index 64cf7f7faf..73f512b478 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs @@ -468,7 +468,7 @@ private string CreateLoaderArguments(ISchema schema, ValueWriter[] pipes, bool h sb.Append(" col="); if (!column.TryUnparse(sb)) { - var settings = CmdParser.GetSettings(ch, column, new TextLoader.Column()); + var settings = CmdParser.GetSettings(_host, column, new TextLoader.Column()); CmdQuoter.QuoteValue(settings, sb, true); } if (type.IsVector && !type.IsKnownSizeVector && i != pipes.Length - 1) diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs index 64130ed80e..cd4450615d 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs @@ -77,7 +77,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(TransformWrapper).Assembly.FullName); } // Factory for SignatureLoadModel. diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs index d5f246689a..31ab59a100 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs @@ -55,7 +55,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: TransformerChain.LoaderSignature); + loaderSignature: TransformerChain.LoaderSignature, + loaderAssemblyName: typeof(TransformerChain<>).Assembly.FullName); } /// diff --git a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs index 37cbe23b92..a6c12b48ed 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs @@ -354,7 +354,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoadName); + loaderSignature: LoadName, + loaderAssemblyName: typeof(TransposeLoader).Assembly.FullName); } // We return the schema view's schema, because we don't necessarily want diff --git a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs index ef19012bf6..799741fa1a 100644 --- a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs +++ b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs @@ -241,7 +241,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(RowToRowMapperTransform).Assembly.FullName); } public override ISchema Schema { get { return _bindings; } } diff --git a/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs b/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs index 58ba66e091..dfa90bfec6 100644 --- a/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs +++ b/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs @@ -222,7 +222,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(FeatureNameCollection).Assembly.FullName); } public static void Save(ModelSaveContext ctx, ref VBuffer> names) diff --git a/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs b/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs index da237611b9..df0e0ff049 100644 --- a/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs +++ b/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs @@ -189,7 +189,8 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderSignatureAlt: LoaderSignatureOld); + loaderSignatureAlt: LoaderSignatureOld, + loaderAssemblyName: typeof(ChooseColumnsByIndexTransform).Assembly.FullName); } private readonly Bindings _bindings; diff --git a/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs b/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs index 260ff37f26..195b2ec5f6 100644 --- a/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs @@ -83,7 +83,7 @@ public static Output Score(IHostEnvironment env, Input input) ch.AssertValue(bindable); var mapper = bindable.Bind(host, data.Schema); - var scorer = ScoreUtils.GetScorerComponent(mapper, input.Suffix); + var scorer = ScoreUtils.GetScorerComponent(host, mapper, input.Suffix); scoredPipe = scorer.CreateComponent(host, data.Data, mapper, input.PredictorModel.GetTrainingSchema(host)); ch.Done(); } @@ -134,7 +134,7 @@ public static Output MakeScoringTransform(IHostEnvironment env, ModelInput input ch.AssertValue(bindable); var mapper = bindable.Bind(host, data.Schema); - var scorer = ScoreUtils.GetScorerComponent(mapper); + var scorer = ScoreUtils.GetScorerComponent(host, mapper); scoredPipe = scorer.CreateComponent(host, data.Data, mapper, input.PredictorModel.GetTrainingSchema(host)); ch.Done(); } diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs index a92526dfa9..2247c32637 100644 --- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs @@ -1034,7 +1034,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(BinaryPerInstanceEvaluator).Assembly.FullName); } private const int AssignedCol = 0; diff --git a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs index fbbafa8775..2c476eb468 100644 --- a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs @@ -529,7 +529,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(ClusteringPerInstanceEvaluator).Assembly.FullName); } private const int ClusterIdCol = 0; diff --git a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs index 84fd29ea50..c264d438b0 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs @@ -385,7 +385,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(MultiOutputRegressionPerInstanceEvaluator).Assembly.FullName); } private const int LabelOutput = 0; diff --git a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs index 900df4ac53..205d74a526 100644 --- a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs @@ -671,7 +671,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Serialize the class names verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(MultiClassPerInstanceEvaluator).Assembly.FullName); } private const int AssignedCol = 0; diff --git a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs index d4c70c6eac..9a2f22de5d 100644 --- a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs @@ -262,7 +262,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(QuantileRegressionPerInstanceEvaluator).Assembly.FullName); } private const int L1Col = 0; diff --git a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs index 131f16b058..d071a906e3 100644 --- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs @@ -523,7 +523,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(RankerPerInstanceTransform).Assembly.FullName); } public const string Ndcg = "NDCG"; diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs index b20492d6dc..203050d774 100644 --- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs @@ -293,7 +293,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(RegressionPerInstanceEvaluator).Assembly.FullName); } private const int L1Col = 0; diff --git a/src/Microsoft.ML.Data/Model/ModelHeader.cs b/src/Microsoft.ML.Data/Model/ModelHeader.cs index 067ff12285..37a5b9ac92 100644 --- a/src/Microsoft.ML.Data/Model/ModelHeader.cs +++ b/src/Microsoft.ML.Data/Model/ModelHeader.cs @@ -20,11 +20,14 @@ public struct ModelHeader public const ulong SignatureValue = 0x4C45444F4D004C4DUL; public const ulong TailSignatureValue = 0x4D4C004D4F44454CUL; + private const uint VerAssemblyNameSupported = 0x00010002; + // These are private since they change over time. If we make them public we risk // another assembly containing a "copy" of their value when the other assembly // was compiled, which might not match the code that can load this. - private const uint VerWrittenCur = 0x00010001; - private const uint VerReadableCur = 0x00010001; + //private const uint VerWrittenCur = 0x00010001; // Initial + private const uint VerWrittenCur = 0x00010002; // Added AssemblyName + private const uint VerReadableCur = 0x00010002; private const uint VerWeCanReadBack = 0x00010001; [FieldOffset(0x00)] @@ -85,7 +88,13 @@ public struct ModelHeader [FieldOffset(0x88)] public long FpLim; - // Lots of padding.... + // Location of the fully qualified assembly name string (in UTF-16). + // Note that it is legal for both to be zero. + [FieldOffset(0x90)] + public long FpAssemblyName; + [FieldOffset(0x98)] + public uint CbAssemblyName; + public const int Size = 0x0100; // Utilities for writing. @@ -116,7 +125,7 @@ public static void BeginWrite(BinaryWriter writer, out long fpMin, out ModelHead /// The current writer position should be the end of the model blob. Records the model size, writes the string table, /// completes and writes the header, and writes the tail. /// - public static void EndWrite(BinaryWriter writer, long fpMin, ref ModelHeader header, NormStr.Pool pool = null) + public static void EndWrite(BinaryWriter writer, long fpMin, ref ModelHeader header, NormStr.Pool pool = null, string loaderAssemblyName = null) { Contracts.CheckValue(writer, nameof(writer)); Contracts.CheckParam(fpMin >= 0, nameof(fpMin)); @@ -157,9 +166,28 @@ public static void EndWrite(BinaryWriter writer, long fpMin, ref ModelHeader hea Contracts.Assert(offset == header.CbStringChars); } + WriteLoaderAssemblyName(writer, fpMin, ref header, loaderAssemblyName); + WriteHeaderAndTailCore(writer, fpMin, ref header); } + private static void WriteLoaderAssemblyName(BinaryWriter writer, long fpMin, ref ModelHeader header, string loaderAssemblyName) + { + if (!string.IsNullOrEmpty(loaderAssemblyName)) + { + header.FpAssemblyName = writer.FpCur() - fpMin; + header.CbAssemblyName = (uint)loaderAssemblyName.Length * sizeof(char); + + foreach (var ch in loaderAssemblyName) + writer.Write((short)ch); + } + else + { + header.FpAssemblyName = 0; + header.CbAssemblyName = 0; + } + } + /// /// The current writer position should be where the tail belongs. Writes the header and tail. /// Typically this isn't called directly unless you are doing custom string table serialization. @@ -289,7 +317,7 @@ public static void MarshalToBytes(ref ModelHeader header, byte[] bytes) /// Read the model header, strings, etc from reader. Also validates the header (throws if bad). /// Leaves the reader position at the beginning of the model blob. /// - public static void BeginRead(out long fpMin, out ModelHeader header, out string[] strings, BinaryReader reader) + public static void BeginRead(out long fpMin, out ModelHeader header, out string[] strings, out string loaderAssemblyName, BinaryReader reader) { fpMin = reader.FpCur(); @@ -298,7 +326,7 @@ public static void BeginRead(out long fpMin, out ModelHeader header, out string[ ModelHeader.MarshalFromBytes(out header, headerBytes); Exception ex; - if (!ModelHeader.TryValidate(ref header, reader, fpMin, out strings, out ex)) + if (!ModelHeader.TryValidate(ref header, reader, fpMin, out strings, out loaderAssemblyName, out ex)) throw ex; reader.Seek(header.FpModel + fpMin); @@ -375,7 +403,10 @@ public static bool TryValidate(ref ModelHeader header, long size, out Exception Contracts.CheckDecode(header.CbStringTable == 0); Contracts.CheckDecode(header.FpStringChars == 0); Contracts.CheckDecode(header.CbStringChars == 0); - Contracts.CheckDecode(header.FpTail == header.FpModel + header.CbModel); + if (header.VerWritten < VerAssemblyNameSupported || header.FpAssemblyName == 0) + { + Contracts.CheckDecode(header.FpTail == header.FpModel + header.CbModel); + } } else { @@ -387,8 +418,34 @@ public static bool TryValidate(ref ModelHeader header, long size, out Exception Contracts.CheckDecode(header.FpStringChars == header.FpStringTable + header.CbStringTable); Contracts.CheckDecode(header.CbStringChars % sizeof(char) == 0); Contracts.CheckDecode(header.FpStringChars + header.CbStringChars >= header.FpStringChars); - Contracts.CheckDecode(header.FpTail == header.FpStringChars + header.CbStringChars); + if (header.VerWritten < VerAssemblyNameSupported || header.FpAssemblyName == 0) + { + Contracts.CheckDecode(header.FpTail == header.FpStringChars + header.CbStringChars); + } } + + if (header.VerWritten >= VerAssemblyNameSupported) + { + if (header.FpAssemblyName == 0) + { + Contracts.CheckDecode(header.CbAssemblyName == 0); + } + else + { + // the assembly name always immediately after the string table, if there is one + if (header.FpStringTable == 0) + { + Contracts.CheckDecode(header.FpAssemblyName == header.FpModel + header.CbModel); + } + else + { + Contracts.CheckDecode(header.FpAssemblyName == header.FpStringChars + header.CbStringChars); + } + Contracts.CheckDecode(header.CbAssemblyName % sizeof(char) == 0); + Contracts.CheckDecode(header.FpTail == header.FpAssemblyName + header.CbAssemblyName); + } + } + Contracts.CheckDecode(header.FpLim == header.FpTail + sizeof(ulong)); Contracts.CheckDecode(size == 0 || size >= header.FpLim); @@ -405,7 +462,7 @@ public static bool TryValidate(ref ModelHeader header, long size, out Exception /// /// Checks the validity of the header, reads the string table, etc. /// - public static bool TryValidate(ref ModelHeader header, BinaryReader reader, long fpMin, out string[] strings, out Exception ex) + public static bool TryValidate(ref ModelHeader header, BinaryReader reader, long fpMin, out string[] strings, out string loaderAssemblyName, out Exception ex) { Contracts.CheckValue(reader, nameof(reader)); Contracts.Check(fpMin >= 0); @@ -413,62 +470,85 @@ public static bool TryValidate(ref ModelHeader header, BinaryReader reader, long if (!TryValidate(ref header, reader.BaseStream.Length - fpMin, out ex)) { strings = null; + loaderAssemblyName = null; return false; } - if (header.FpStringTable == 0) - { - // No strings. - strings = null; - ex = null; - return true; - } - try { long fpOrig = reader.FpCur(); - reader.Seek(header.FpStringTable + fpMin); - Contracts.Assert(reader.FpCur() == header.FpStringTable + fpMin); - long cstr = header.CbStringTable / sizeof(long); - Contracts.Assert(cstr < int.MaxValue); - long[] offsets = reader.ReadLongArray((int)cstr); - Contracts.Assert(header.FpStringChars == reader.FpCur() - fpMin); - Contracts.CheckDecode(offsets[cstr - 1] == header.CbStringChars); + StringBuilder sb = null; + if (header.FpStringTable == 0) + { + // No strings. + strings = null; + } + else + { + reader.Seek(header.FpStringTable + fpMin); + Contracts.Assert(reader.FpCur() == header.FpStringTable + fpMin); + + long cstr = header.CbStringTable / sizeof(long); + Contracts.Assert(cstr < int.MaxValue); + long[] offsets = reader.ReadLongArray((int)cstr); + Contracts.Assert(header.FpStringChars == reader.FpCur() - fpMin); + Contracts.CheckDecode(offsets[cstr - 1] == header.CbStringChars); + + strings = new string[cstr]; + long offset = 0; + sb = new StringBuilder(); + for (int i = 0; i < offsets.Length; i++) + { + Contracts.CheckDecode(header.FpStringChars + offset == reader.FpCur() - fpMin); + + long offsetPrev = offset; + offset = offsets[i]; + Contracts.CheckDecode(offsetPrev <= offset & offset <= header.CbStringChars); + Contracts.CheckDecode(offset % sizeof(char) == 0); + long cch = (offset - offsetPrev) / sizeof(char); + Contracts.CheckDecode(cch < int.MaxValue); + + sb.Clear(); + for (long ich = 0; ich < cch; ich++) + sb.Append((char)reader.ReadUInt16()); + strings[i] = sb.ToString(); + } + Contracts.CheckDecode(offset == header.CbStringChars); + Contracts.CheckDecode(header.FpStringChars + header.CbStringChars == reader.FpCur() - fpMin); + } - strings = new string[cstr]; - long offset = 0; - var sb = new StringBuilder(); - for (int i = 0; i < offsets.Length; i++) + if (header.VerWritten >= VerAssemblyNameSupported && header.FpAssemblyName != 0) { - Contracts.CheckDecode(header.FpStringChars + offset == reader.FpCur() - fpMin); + reader.Seek(header.FpAssemblyName + fpMin); + int assemblyNameLength = (int)header.CbAssemblyName / sizeof(char); - long offsetPrev = offset; - offset = offsets[i]; - Contracts.CheckDecode(offsetPrev <= offset & offset <= header.CbStringChars); - Contracts.CheckDecode(offset % sizeof(char) == 0); - long cch = (offset - offsetPrev) / sizeof(char); - Contracts.CheckDecode(cch < int.MaxValue); + sb = sb != null ? sb.Clear() : new StringBuilder(assemblyNameLength); - sb.Clear(); - for (long ich = 0; ich < cch; ich++) + for (long ich = 0; ich < assemblyNameLength; ich++) sb.Append((char)reader.ReadUInt16()); - strings[i] = sb.ToString(); + + loaderAssemblyName = sb.ToString(); } - Contracts.CheckDecode(offset == header.CbStringChars); - Contracts.CheckDecode(header.FpStringChars + header.CbStringChars == reader.FpCur() - fpMin); + else + { + loaderAssemblyName = null; + } + Contracts.CheckDecode(header.FpTail == reader.FpCur() - fpMin); ulong tail = reader.ReadUInt64(); Contracts.CheckDecode(tail == TailSignatureValue, "Corrupt model file tail"); - reader.Seek(fpOrig); ex = null; + + reader.Seek(fpOrig); return true; } catch (Exception e) { strings = null; + loaderAssemblyName = null; ex = e; return false; } @@ -529,12 +609,13 @@ public static string GetLoaderSigAlt(ref ModelHeader header) /// This is used to simplify version checking boiler-plate code. It is an optional /// utility type. /// - public struct VersionInfo + public readonly struct VersionInfo { public readonly ulong ModelSignature; public readonly uint VerWrittenCur; public readonly uint VerReadableCur; public readonly uint VerWeCanReadBack; + public readonly string LoaderAssemblyName; public readonly string LoaderSignature; public readonly string LoaderSignatureAlt; @@ -543,7 +624,7 @@ public struct VersionInfo /// all less than 0x100. Spaces are mapped to zero. This assumes little-endian. /// public VersionInfo(string modelSignature, uint verWrittenCur, uint verReadableCur, uint verWeCanReadBack, - string loaderSignature = null, string loaderSignatureAlt = null) + string loaderAssemblyName, string loaderSignature = null, string loaderSignatureAlt = null) { Contracts.Check(Utils.Size(modelSignature) == 8, "Model signature must be eight characters"); ModelSignature = 0; @@ -559,17 +640,7 @@ public VersionInfo(string modelSignature, uint verWrittenCur, uint verReadableCu VerWrittenCur = verWrittenCur; VerReadableCur = verReadableCur; VerWeCanReadBack = verWeCanReadBack; - LoaderSignature = loaderSignature; - LoaderSignatureAlt = loaderSignatureAlt; - } - - public VersionInfo(ulong modelSignature, uint verWrittenCur, uint verReadableCur, uint verWeCanReadBack, - string loaderSignature = null, string loaderSignatureAlt = null) - { - ModelSignature = modelSignature; - VerWrittenCur = verWrittenCur; - VerReadableCur = verReadableCur; - VerWeCanReadBack = verWeCanReadBack; + LoaderAssemblyName = loaderAssemblyName; LoaderSignature = loaderSignature; LoaderSignatureAlt = loaderSignatureAlt; } diff --git a/src/Microsoft.ML.Data/Model/ModelLoadContext.cs b/src/Microsoft.ML.Data/Model/ModelLoadContext.cs index c09b7b09a2..7efd9f4b47 100644 --- a/src/Microsoft.ML.Data/Model/ModelLoadContext.cs +++ b/src/Microsoft.ML.Data/Model/ModelLoadContext.cs @@ -39,6 +39,14 @@ public sealed partial class ModelLoadContext : IDisposable /// public readonly string[] Strings; + /// + /// The name of the assembly that the loader lives in. + /// + /// + /// This may be null or empty if one was never written to the model, or is an older model version. + /// + public readonly string LoaderAssemblyName; + /// /// The main stream's model header. /// @@ -76,7 +84,7 @@ public ModelLoadContext(RepositoryReader rep, Repository.Entry ent, string dir) Reader = new BinaryReader(ent.Stream, Encoding.UTF8, leaveOpen: true); try { - ModelHeader.BeginRead(out FpMin, out Header, out Strings, Reader); + ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader); } catch { @@ -97,7 +105,7 @@ public ModelLoadContext(BinaryReader reader, IExceptionContext ectx = null) Repository = null; Directory = null; Reader = reader; - ModelHeader.BeginRead(out FpMin, out Header, out Strings, Reader); + ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader); } public void CheckAtModel() diff --git a/src/Microsoft.ML.Data/Model/ModelLoading.cs b/src/Microsoft.ML.Data/Model/ModelLoading.cs index c69461ea79..981da9e797 100644 --- a/src/Microsoft.ML.Data/Model/ModelLoading.cs +++ b/src/Microsoft.ML.Data/Model/ModelLoading.cs @@ -4,6 +4,7 @@ using System; using System.IO; +using System.Reflection; using System.Text; using Microsoft.ML.Runtime.Internal.Utilities; @@ -212,6 +213,8 @@ private bool TryLoadModelCore(IHostEnvironment env, out TRes result, var args = ConcatArgsRev(extra, this); + EnsureLoaderAssemblyIsRegistered(env.ComponentCatalog); + object tmp; string sig = ModelHeader.GetLoaderSig(ref Header); if (!string.IsNullOrWhiteSpace(sig) && @@ -246,6 +249,15 @@ private bool TryLoadModelCore(IHostEnvironment env, out TRes result, return false; } + private void EnsureLoaderAssemblyIsRegistered(ComponentCatalog catalog) + { + if (!string.IsNullOrEmpty(LoaderAssemblyName)) + { + var assembly = Assembly.Load(LoaderAssemblyName); + catalog.RegisterAssembly(assembly); + } + } + private static object[] ConcatArgsRev(object[] args2, params object[] args1) { Contracts.AssertNonEmpty(args1); diff --git a/src/Microsoft.ML.Data/Model/ModelSaveContext.cs b/src/Microsoft.ML.Data/Model/ModelSaveContext.cs index 5e32893e7d..c5a9199758 100644 --- a/src/Microsoft.ML.Data/Model/ModelSaveContext.cs +++ b/src/Microsoft.ML.Data/Model/ModelSaveContext.cs @@ -24,7 +24,7 @@ public sealed partial class ModelSaveContext : IDisposable public readonly RepositoryWriter Repository; /// - /// When in repository mode, this is the direcory we're reading from. Null means the root + /// When in repository mode, this is the directory we're reading from. Null means the root /// of the repository. It is always null in single-stream mode. /// public readonly string Directory; @@ -59,6 +59,11 @@ public sealed partial class ModelSaveContext : IDisposable /// private readonly IExceptionContext _ectx; + /// + /// The assembly name where the loader resides. + /// + private string _loaderAssemblyName; + /// /// Returns whether this context is in repository mode (true) or single-stream mode (false). /// @@ -131,6 +136,7 @@ public void CheckAtModel() public void SetVersionInfo(VersionInfo ver) { ModelHeader.SetVersionInfo(ref Header, ver); + _loaderAssemblyName = ver.LoaderAssemblyName; } public void SaveTextStream(string name, Action action) @@ -212,7 +218,7 @@ public void SaveNonEmptyString(ReadOnlyMemory str) public void Done() { _ectx.Check(Header.ModelSignature != 0, "ModelSignature not specified!"); - ModelHeader.EndWrite(Writer, FpMin, ref Header, Strings); + ModelHeader.EndWrite(Writer, FpMin, ref Header, Strings, _loaderAssemblyName); Dispose(); } diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 6cee31b9d2..310774a8a5 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -332,7 +332,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(CalibratedPredictor).Assembly.FullName); } private static VersionInfo GetVersionInfoBulk() { @@ -341,7 +342,8 @@ private static VersionInfo GetVersionInfoBulk() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(CalibratedPredictor).Assembly.FullName); } private CalibratedPredictor(IHostEnvironment env, ModelLoadContext ctx) @@ -393,7 +395,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(FeatureWeightsCalibratedPredictor).Assembly.FullName); } private FeatureWeightsCalibratedPredictor(IHostEnvironment env, ModelLoadContext ctx) @@ -455,7 +458,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(ParameterMixingCalibratedPredictor).Assembly.FullName); } private ParameterMixingCalibratedPredictor(IHostEnvironment env, ModelLoadContext ctx) @@ -608,7 +612,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(SchemaBindableCalibratedPredictor).Assembly.FullName); } /// @@ -967,7 +972,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(NaiveCalibrator).Assembly.FullName); } private readonly IHost _host; @@ -1334,7 +1340,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(PlattCalibrator).Assembly.FullName); } private readonly IHost _host; @@ -1569,7 +1576,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(PavCalibrator).Assembly.FullName); } // Epsilon for 0-comparisons diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs index ee76b5a674..dc9a5bbdf9 100644 --- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs @@ -40,7 +40,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010004, // ISchemaBindableMapper update verReadableCur: 0x00010004, verWeCanReadBack: 0x00010004, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(BinaryClassifierScorer).Assembly.FullName); } private const string RegistrationName = "BinaryClassifierScore"; diff --git a/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs b/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs index 7a065fe4d4..26eae0154d 100644 --- a/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs @@ -37,7 +37,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010003, // ISchemaBindableMapper update verReadableCur: 0x00010003, verWeCanReadBack: 0x00010003, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(ClusteringScorer).Assembly.FullName); } private const string RegistrationName = "ClusteringScore"; diff --git a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs index 41c12e94ed..0b261502ed 100644 --- a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs @@ -130,7 +130,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(GenericScorer).Assembly.FullName); } private const string RegistrationName = "GenericScore"; diff --git a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs index c12fd9b4d1..79387c2727 100644 --- a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs @@ -48,7 +48,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010003, // ISchemaBindableMapper update verReadableCur: 0x00010003, verWeCanReadBack: 0x00010003, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(MultiClassClassifierScorer).Assembly.FullName); } private const string RegistrationName = "MultiClassClassifierScore"; @@ -88,7 +89,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Added metadataKind verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(LabelNameBindableMapper).Assembly.FullName); } private const int VersionAddedMetadataKind = 0x00010002; diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index aa71ac2dc5..904b591dd4 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -256,7 +256,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: BinaryPredictionTransformer.LoaderSignature); + loaderSignature: BinaryPredictionTransformer.LoaderSignature, + loaderAssemblyName: typeof(BinaryPredictionTransformer<>).Assembly.FullName); } } @@ -321,7 +322,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: MulticlassPredictionTransformer.LoaderSignature); + loaderSignature: MulticlassPredictionTransformer.LoaderSignature, + loaderAssemblyName: typeof(MulticlassPredictionTransformer<>).Assembly.FullName); } } @@ -371,7 +373,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: RegressionPredictionTransformer.LoaderSignature); + loaderSignature: RegressionPredictionTransformer.LoaderSignature, + loaderAssemblyName: typeof(RegressionPredictionTransformer<>).Assembly.FullName); } } @@ -417,7 +420,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: RankingPredictionTransformer.LoaderSignature); + loaderSignature: RankingPredictionTransformer.LoaderSignature, + loaderAssemblyName: typeof(RankingPredictionTransformer<>).Assembly.FullName); } } diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs index 31b921a064..8422d5042c 100644 --- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs +++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs @@ -242,7 +242,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // ISchemaBindableWrapper update verReadableCur: 0x00010002, verWeCanReadBack: 0x00010002, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(SchemaBindablePredictorWrapper).Assembly.FullName); } private readonly string _scoreColumnKind; @@ -353,7 +354,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // ISchemaBindableWrapper update verReadableCur: 0x00010002, verWeCanReadBack: 0x00010002, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(SchemaBindableBinaryPredictorWrapper).Assembly.FullName); } private readonly IValueMapperDist _distMapper; @@ -581,7 +583,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // ISchemaBindableWrapper update verReadableCur: 0x00010002, verWeCanReadBack: 0x00010002, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(SchemaBindableQuantileRegressionPredictor).Assembly.FullName); } private readonly IQuantileValueMapper _qpred; diff --git a/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs index 1459f55cab..e1300e38cd 100644 --- a/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs @@ -449,7 +449,8 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderSignatureAlt: LoaderSignatureOld); + loaderSignatureAlt: LoaderSignatureOld, + loaderAssemblyName: typeof(ChooseColumnsTransform).Assembly.FullName); } private readonly Bindings _bindings; diff --git a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs index d827947192..1ec6c13d61 100644 --- a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs @@ -243,7 +243,8 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderSignatureAlt: LoaderSignatureOld); + loaderSignatureAlt: LoaderSignatureOld, + loaderAssemblyName: typeof(ConcatTransform).Assembly.FullName); } private const int VersionAddedAliases = 0x00010002; diff --git a/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs b/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs index eec59c8a8f..62906ffa55 100644 --- a/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs @@ -161,7 +161,8 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderSignatureAlt: LoaderSignatureOld); + loaderSignatureAlt: LoaderSignatureOld, + loaderAssemblyName: typeof(ConvertTransform).Assembly.FullName); } private const string RegistrationName = "Convert"; diff --git a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs index 478d509c24..8ca1260ebf 100644 --- a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs @@ -93,7 +93,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(CopyColumnsTransform).Assembly.FullName); } public CopyColumnsTransform(IHostEnvironment env, params (string source, string name)[] columns) @@ -224,7 +225,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(CopyColumnsRowMapper).Assembly.FullName); } // Factory method for SignatureLoadRowMapper. diff --git a/src/Microsoft.ML.Data/Transforms/DropColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/DropColumnsTransform.cs index 3e15199ff7..aa5026a914 100644 --- a/src/Microsoft.ML.Data/Transforms/DropColumnsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/DropColumnsTransform.cs @@ -229,7 +229,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Added KeepColumns verReadableCur: 0x00010002, verWeCanReadBack: 0x00010002, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(DropColumnsTransform).Assembly.FullName); } private readonly Bindings _bindings; diff --git a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs index 9945ee6cc3..6f8185b75a 100644 --- a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs @@ -188,7 +188,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(DropSlotsTransform).Assembly.FullName); } private const string RegistrationName = "DropSlots"; diff --git a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs index c141bb66c1..11adab9ab6 100644 --- a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs @@ -249,7 +249,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(GenerateNumberTransform).Assembly.FullName); } private readonly Bindings _bindings; diff --git a/src/Microsoft.ML.Data/Transforms/HashTransform.cs b/src/Microsoft.ML.Data/Transforms/HashTransform.cs index a589121be6..6ad7f68cbf 100644 --- a/src/Microsoft.ML.Data/Transforms/HashTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/HashTransform.cs @@ -196,7 +196,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Invert hash key values, hash fix verReadableCur: 0x00010002, verWeCanReadBack: 0x00010002, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(HashTransformer).Assembly.FullName); } private readonly ColumnInfo[] _columns; diff --git a/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs b/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs index db6ff54298..4c90bc1d9d 100644 --- a/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs +++ b/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs @@ -329,7 +329,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(TextModelHelper).Assembly.FullName); } private static void Load(IChannel ch, ModelLoadContext ctx, CodecFactory factory, ref VBuffer> values) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs index de8463a804..d9e75fb178 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs @@ -76,7 +76,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(KeyToValueTransform).Assembly.FullName); } /// diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index bda8729d4a..fd177d6a14 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -143,7 +143,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Get rid of writing float size in model context verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(KeyToVectorTransform).Assembly.FullName); } public override void Save(ModelSaveContext ctx) diff --git a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs index 8817833f40..b5943b14fa 100644 --- a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs @@ -58,7 +58,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial. verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(LabelConvertTransform).Assembly.FullName); } private const string RegistrationName = "LabelConvert"; diff --git a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs index 81882ac749..5844d61d31 100644 --- a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs @@ -39,7 +39,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(LabelIndicatorTransform).Assembly.FullName); } public sealed class Column : OneToOneColumn diff --git a/src/Microsoft.ML.Data/Transforms/NAFilter.cs b/src/Microsoft.ML.Data/Transforms/NAFilter.cs index c8515291f3..ed738f3e63 100644 --- a/src/Microsoft.ML.Data/Transforms/NAFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/NAFilter.cs @@ -70,7 +70,8 @@ private static VersionInfo GetVersionInfo() loaderSignature: LoaderSignature, // This is an older name and can be removed once we don't care about old code // being able to load this. - loaderSignatureAlt: "MissingFeatureFilter"); + loaderSignatureAlt: "MissingFeatureFilter", + loaderAssemblyName: typeof(NAFilter).Assembly.FullName); } private readonly ColInfo[] _infos; diff --git a/src/Microsoft.ML.Data/Transforms/NopTransform.cs b/src/Microsoft.ML.Data/Transforms/NopTransform.cs index 0a71a3ec07..a367073f2f 100644 --- a/src/Microsoft.ML.Data/Transforms/NopTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/NopTransform.cs @@ -55,7 +55,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(NopTransform).Assembly.FullName); } internal static string RegistrationName = "NopTransform"; diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs index b515b2293c..a91c4645eb 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs @@ -615,7 +615,8 @@ public static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(CdfColumnFunction).Assembly.FullName); } } @@ -677,7 +678,8 @@ public static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(BinColumnFunction).Assembly.FullName); } } @@ -1139,7 +1141,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010003, // Scales multiply instead of divide verReadableCur: 0x00010003, verWeCanReadBack: 0x00010003, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(AffineNormSerializationUtils).Assembly.FullName); } } @@ -1154,7 +1157,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(BinNormSerializationUtils).Assembly.FullName); } } diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs index 7cf6bd3d31..4c3d9fd51d 100644 --- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs +++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs @@ -216,7 +216,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(NormalizerTransformer).Assembly.FullName); } private class ColumnInfo diff --git a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs index 142779dee2..d20446f7f5 100644 --- a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs @@ -64,7 +64,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(RangeFilter).Assembly.FullName); } private const string RegistrationName = "RangeFilter"; diff --git a/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs b/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs index 3940bbe979..4043e3c42b 100644 --- a/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs @@ -71,7 +71,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Force shuffle source saving verReadableCur: 0x00010002, verWeCanReadBack: 0x00010002, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(ShuffleTransform).Assembly.FullName); } private const string RegistrationName = "Shuffle"; diff --git a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs index 2adb17258e..440d68ae84 100644 --- a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs @@ -76,7 +76,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(SkipTakeFilter).Assembly.FullName); } private readonly long _skip; diff --git a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs index 55db806941..c3084a9fa7 100644 --- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs @@ -193,7 +193,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010003, // Generalize to multiple types beyond text verReadableCur: 0x00010003, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(TermTransform).Assembly.FullName); } private const uint VerNonTextTypesSupported = 0x00010003; @@ -224,7 +225,8 @@ private static VersionInfo GetTermManagerVersionInfo() verWrittenCur: 0x00010002, // Generalize to multiple types beyond text verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, - loaderSignature: TermManagerLoaderSignature); + loaderSignature: TermManagerLoaderSignature, + loaderAssemblyName: typeof(TermTransform).Assembly.FullName); } private readonly TermMap[] _unboundMaps; diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs index 45cd764d13..de1e5ef505 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs @@ -25,7 +25,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(Average).Assembly.FullName); } public Average(IHostEnvironment env) diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs index de8d950de4..95bc0cc991 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs @@ -30,7 +30,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(Median).Assembly.FullName); } public Median(IHostEnvironment env) diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs index c147f932f3..fef6fa087e 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs @@ -28,7 +28,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(MultiAverage).Assembly.FullName); } [TlcModule.Component(Name = LoadName, FriendlyName = Average.UserName)] diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs index c3e6869d69..86312393de 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs @@ -31,7 +31,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(MultiMedian).Assembly.FullName); } [TlcModule.Component(Name = LoadName, FriendlyName = Median.UserName)] diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs index 67af075f7d..fd6af4dc53 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs @@ -35,7 +35,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(MultiStacking).Assembly.FullName); } [TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)] diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs index ee55b94c77..0c68f70287 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs @@ -29,7 +29,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(MultiVoting).Assembly.FullName); } private sealed class Arguments : ArgumentsBase diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs index 9bda1d151a..a2f52b1451 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs @@ -35,7 +35,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(MultiWeightedAverage).Assembly.FullName); } [TlcModule.Component(Name = LoadName, FriendlyName = UserName)] diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs index fa68546eb8..f1503ed23d 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs @@ -33,7 +33,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(RegressionStacking).Assembly.FullName); } [TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)] diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs index 63adbd7c56..9a569cb2af 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs @@ -31,7 +31,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(Stacking).Assembly.FullName); } [TlcModule.Component(Name = LoadName, FriendlyName = UserName)] diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs index 932f99d93a..d352439d55 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs @@ -27,7 +27,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(Voting).Assembly.FullName); } public Voting(IHostEnvironment env) diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs index 8b16ffd0a2..e11f0c9bb9 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs @@ -32,7 +32,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(WeightedAverage).Assembly.FullName); } [TlcModule.Component(Name = LoadName, FriendlyName = UserName)] diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs index e4d6cb0d9e..a13903f1a3 100644 --- a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs @@ -377,7 +377,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Save predictor models in a subdirectory verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(SchemaBindablePipelineEnsembleBase).Assembly.FullName); } public const string UserName = "Pipeline Ensemble"; public const string LoaderSignature = "PipelineEnsemble"; diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs index 547800c152..a15563a289 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs @@ -37,7 +37,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010003, // Don't serialize the "IsAveraged" property of the metrics verReadableCur: 0x00010003, verWeCanReadBack: 0x00010002, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(EnsembleDistributionPredictor).Assembly.FullName); } private readonly Single[] _averagedWeights; diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs index 08c8f0dd8d..1f2ed87bf6 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs @@ -35,7 +35,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010003, // Don't serialize the "IsAveraged" property of the metrics verReadableCur: 0x00010003, verWeCanReadBack: 0x00010002, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(EnsemblePredictor).Assembly.FullName); } private readonly IValueMapper[] _mappers; diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs index 558d0afd6e..4dfaf3983a 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs @@ -32,7 +32,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010003, // Don't serialize the "IsAveraged" property of the metrics verReadableCur: 0x00010003, verWeCanReadBack: 0x00010002, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(EnsembleMultiClassPredictor).Assembly.FullName); } private readonly ColumnType _inputType; diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 0d9d5bc192..ca103f3c25 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -819,7 +819,7 @@ protected virtual void PrintPrologInfo(IChannel ch) { Contracts.AssertValue(ch); ch.Trace("Host = {0}", Environment.MachineName); - ch.Trace("CommandLine = {0}", CmdParser.GetSettings(ch, Args, new TArgs())); + ch.Trace("CommandLine = {0}", CmdParser.GetSettings(Host, Args, new TArgs())); ch.Trace("GCSettings.IsServerGC = {0}", System.Runtime.GCSettings.IsServerGC); ch.Trace("{0}", Args); } diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 97cc1c43f9..8537555ec9 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -58,7 +58,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010005, //Categorical splits. verReadableCur: 0x00010005, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(FastTreeBinaryPredictor).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010002; diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 90331c47b4..0c8419846a 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -1100,7 +1100,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010005, // Categorical splits. verReadableCur: 0x00010004, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(FastTreeRankingPredictor).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010002; diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 717824b6f5..fde54781d9 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -447,7 +447,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010005, // Categorical splits. verReadableCur: 0x00010004, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(FastTreeRegressionPredictor).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010002; diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index eedbfe7389..c028ef963a 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -442,7 +442,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010003, // Categorical splits. verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(FastTreeTweediePredictor).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010001; diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs index 1448333764..184cbdc565 100644 --- a/src/Microsoft.ML.FastTree/GamClassification.cs +++ b/src/Microsoft.ML.FastTree/GamClassification.cs @@ -132,7 +132,8 @@ public static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(BinaryClassGamPredictor).Assembly.FullName); } public static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs index 119e8c2a85..e55a3ee008 100644 --- a/src/Microsoft.ML.FastTree/GamRegression.cs +++ b/src/Microsoft.ML.FastTree/GamRegression.cs @@ -87,7 +87,8 @@ public static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(RegressionGamPredictor).Assembly.FullName); } public static RegressionGamPredictor Create(IHostEnvironment env, ModelLoadContext ctx) diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index 0deceb1aff..43bbb60e15 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -956,11 +956,11 @@ public sealed class Arguments : DataCommand.ArgumentsBase [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to open the GAM visualization page URL", ShortName = "o", SortOrder = 3)] public bool Open = true; - internal Arguments SetServerIfNeeded(IExceptionContext ectx) + internal Arguments SetServerIfNeeded(IHostEnvironment env) { // We assume that if someone invoked this, they really did mean to start the web server. - if (ectx != null && Server == null) - Server = ServerChannel.CreateDefaultServerFactoryOrNull(ectx); + if (env != null && Server == null) + Server = ServerChannel.CreateDefaultServerFactoryOrNull(env); return this; } } diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index b5ba1f43e7..fe79ab5d46 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -63,7 +63,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010006, // Categorical splits. verReadableCur: 0x00010005, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(FastForestClassificationPredictor).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010003; diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index 1261cd0e37..68294596e4 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -49,7 +49,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010006, // Categorical splits. verReadableCur: 0x00010005, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(FastForestRegressionPredictor).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010003; diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index fcd2550147..6faeba3cd8 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -418,7 +418,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Add _defaultValueForMissing verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(TreeEnsembleFeaturizerBindableMapper).Assembly.FullName); } private readonly IHost _host; diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs index cb7472633f..879a32b95f 100644 --- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs +++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs @@ -500,7 +500,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(OlsLinearRegressionPredictor).Assembly.FullName); } // The following will be null iff RSquaredAdjusted is NaN. diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs index 5744fe384d..a137dcfd6a 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs @@ -74,7 +74,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(ImageGrayscaleTransform).Assembly.FullName); } private const string RegistrationName = "ImageGrayscale"; diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs index 067cd30747..90e3fd6b3e 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs @@ -141,7 +141,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap verReadableCur: 0x00010002, verWeCanReadBack: 0x00010002, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(ImageLoaderTransform).Assembly.FullName); } protected override IRowMapper MakeRowMapper(ISchema schema) diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs index 1e563fbc2c..c2cd0f1a83 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs @@ -297,7 +297,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap verReadableCur: 0x00010002, verWeCanReadBack: 0x00010002, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(ImagePixelExtractorTransform).Assembly.FullName); } private const string RegistrationName = "ImagePixelExtractor"; diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs index 1755430cf3..2de1354205 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs @@ -154,7 +154,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010003, // No more sizeof(float) verReadableCur: 0x00010003, verWeCanReadBack: 0x00010003, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(ImageResizerTransform).Assembly.FullName); } private const string RegistrationName = "ImageScaler"; diff --git a/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs b/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs index 2446b7a7b6..a6f69204c4 100644 --- a/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs @@ -236,7 +236,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap verReadableCur: 0x00010002, verWeCanReadBack: 0x00010002, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(VectorToImageTransform).Assembly.FullName); } private const string RegistrationName = "VectorToImageConverter"; diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPredictor.cs b/src/Microsoft.ML.KMeansClustering/KMeansPredictor.cs index e3249da34d..e1b54f472a 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansPredictor.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansPredictor.cs @@ -38,7 +38,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Allow sparse centroids verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(KMeansPredictor).Assembly.FullName); } public override PredictionKind PredictionKind => PredictionKind.Clustering; diff --git a/src/Microsoft.ML.Legacy/Microsoft.ML.Legacy.csproj b/src/Microsoft.ML.Legacy/Microsoft.ML.Legacy.csproj index bd3d72ab68..a2eeeb046c 100644 --- a/src/Microsoft.ML.Legacy/Microsoft.ML.Legacy.csproj +++ b/src/Microsoft.ML.Legacy/Microsoft.ML.Legacy.csproj @@ -6,6 +6,10 @@ CORECLR + + + + diff --git a/src/Microsoft.ML.Legacy/PredictionModel.cs b/src/Microsoft.ML.Legacy/PredictionModel.cs index 45c4738fe2..29f1bf35e9 100644 --- a/src/Microsoft.ML.Legacy/PredictionModel.cs +++ b/src/Microsoft.ML.Legacy/PredictionModel.cs @@ -127,6 +127,8 @@ public static Task> ReadAsync( using (var environment = new ConsoleEnvironment()) { + AssemblyLoadingUtils.RegisterCurrentLoadedAssemblies(environment); + BatchPredictionEngine predictor = environment.CreateBatchPredictionEngine(stream); diff --git a/src/Microsoft.ML.Legacy/Runtime/Experiment/Experiment.cs b/src/Microsoft.ML.Legacy/Runtime/Experiment/Experiment.cs index 108befb74b..efb9cb33d0 100644 --- a/src/Microsoft.ML.Legacy/Runtime/Experiment/Experiment.cs +++ b/src/Microsoft.ML.Legacy/Runtime/Experiment/Experiment.cs @@ -37,6 +37,8 @@ private sealed class SerializationHelper public Experiment(Runtime.IHostEnvironment env) { _env = env; + AssemblyLoadingUtils.RegisterCurrentLoadedAssemblies(_env); + _catalog = ModuleCatalog.CreateInstance(_env); _jsonNodes = new List(); _serializer = new JsonSerializer(); diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index f788b4feab..6fe77564b7 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -40,7 +40,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010005, // Categorical splits. verReadableCur: 0x00010004, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(LightGbmBinaryPredictor).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010002; diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index 3fe4628182..7373fbbac9 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -38,7 +38,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010005, // Categorical splits. verReadableCur: 0x00010004, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(LightGbmRankingPredictor).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010002; diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs index 0011a8d8e6..08331cb45f 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs @@ -38,7 +38,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010005, // Categorical splits. verReadableCur: 0x00010004, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(LightGbmRegressionPredictor).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010002; diff --git a/src/Microsoft.ML.Maml/HelpCommand.cs b/src/Microsoft.ML.Maml/HelpCommand.cs index 401cb7dd10..6c6d800ac7 100644 --- a/src/Microsoft.ML.Maml/HelpCommand.cs +++ b/src/Microsoft.ML.Maml/HelpCommand.cs @@ -100,7 +100,7 @@ public void Run() public void Run(int? columns) { - ComponentCatalog.CacheClassesExtra(_extraAssemblies); + AssemblyLoadingUtils.LoadAndRegister(_env, _extraAssemblies); using (var ch = _env.Start("Help")) using (var sw = new StringWriter(CultureInfo.InvariantCulture)) @@ -137,7 +137,7 @@ private void ShowHelp(IndentingTextWriter writer, int? columns = null) // Note that we don't check IsHidden here. The current policy is when IsHidden is true, we don't // show the item in "list all" functionality, but will still show help when explicitly requested. - var infos = ComponentCatalog.FindLoadableClasses(name) + var infos = _env.ComponentCatalog.FindLoadableClasses(name) .OrderBy(x => ComponentCatalog.SignatureToString(x.SignatureTypes[0]).ToLowerInvariant()); var kinds = new StringBuilder(); var components = new List(); @@ -188,7 +188,7 @@ private void ShowAllHelp(IndentingTextWriter writer, int? columns = null) { string sig = _kind?.ToLowerInvariant(); - var infos = ComponentCatalog.GetAllClasses() + var infos = _env.ComponentCatalog.GetAllClasses() .OrderBy(info => info.LoadNames[0].ToLowerInvariant()) .ThenBy(info => ComponentCatalog.SignatureToString(info.SignatureTypes[0]).ToLowerInvariant()); var components = new List(); @@ -256,7 +256,7 @@ private void ShowComponents(IndentingTextWriter writer) else { kind = _kind.ToLowerInvariant(); - var sigs = ComponentCatalog.GetAllSignatureTypes(); + var sigs = _env.ComponentCatalog.GetAllSignatureTypes(); typeSig = sigs.FirstOrDefault(t => ComponentCatalog.SignatureToString(t).ToLowerInvariant() == kind); if (typeSig == null) { @@ -272,7 +272,7 @@ private void ShowComponents(IndentingTextWriter writer) writer.WriteLine("Available components for kind '{0}':", ComponentCatalog.SignatureToString(typeSig)); } - var infos = ComponentCatalog.GetAllDerivedClasses(typeRes, typeSig) + var infos = _env.ComponentCatalog.GetAllDerivedClasses(typeRes, typeSig) .Where(x => !x.IsHidden) .OrderBy(x => x.LoadNames[0].ToLowerInvariant()); using (writer.Nest()) @@ -322,7 +322,7 @@ private void ShowAliases(IndentingTextWriter writer, IReadOnlyList names private void ListKinds(IndentingTextWriter writer) { - var sigs = ComponentCatalog.GetAllSignatureTypes() + var sigs = _env.ComponentCatalog.GetAllSignatureTypes() .Select(ComponentCatalog.SignatureToString) .OrderBy(x => x); diff --git a/src/Microsoft.ML.Maml/MAML.cs b/src/Microsoft.ML.Maml/MAML.cs index cac407c21a..c88ee6112c 100644 --- a/src/Microsoft.ML.Maml/MAML.cs +++ b/src/Microsoft.ML.Maml/MAML.cs @@ -55,7 +55,10 @@ public static int Main() private static int MainWithProgress(string args) { + string currentDirectory = Path.GetDirectoryName(typeof(Maml).Module.FullyQualifiedName); + using (var env = CreateEnvironment()) + using (AssemblyLoadingUtils.CreateAssemblyRegistrar(env, currentDirectory)) using (var progressCancel = new CancellationTokenSource()) { var progressTrackerTask = Task.Run(() => TrackProgress(env, progressCancel.Token)); diff --git a/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj b/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj index 88b1e9e55b..219b6bf0b8 100644 --- a/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj +++ b/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj @@ -3,10 +3,14 @@ true CORECLR - Microsoft.ML - netstandard2.0 + Microsoft.ML + netstandard2.0 + + + + diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs index 2654af5d13..b784f1f6ee 100644 --- a/src/Microsoft.ML.PCA/PcaTrainer.cs +++ b/src/Microsoft.ML.PCA/PcaTrainer.cs @@ -307,7 +307,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(PcaPredictor).Assembly.FullName); } private readonly int _dimension; diff --git a/src/Microsoft.ML.PCA/PcaTransform.cs b/src/Microsoft.ML.PCA/PcaTransform.cs index abb4b9b821..dd1596f330 100644 --- a/src/Microsoft.ML.PCA/PcaTransform.cs +++ b/src/Microsoft.ML.PCA/PcaTransform.cs @@ -184,7 +184,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(PcaTransform).Assembly.FullName); } // These are parallel to Infos. diff --git a/src/Microsoft.ML.Parquet/ParquetLoader.cs b/src/Microsoft.ML.Parquet/ParquetLoader.cs index e131949138..667600c44e 100644 --- a/src/Microsoft.ML.Parquet/ParquetLoader.cs +++ b/src/Microsoft.ML.Parquet/ParquetLoader.cs @@ -109,7 +109,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Add Schema to Model Context verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(ParquetLoader).Assembly.FullName); } public ParquetLoader(IHostEnvironment env, Arguments args, IMultiStreamSource files) diff --git a/src/Microsoft.ML.PipelineInference/ExperimentsGenerator.cs b/src/Microsoft.ML.PipelineInference/ExperimentsGenerator.cs index 053b86a40b..ee0fa11bde 100644 --- a/src/Microsoft.ML.PipelineInference/ExperimentsGenerator.cs +++ b/src/Microsoft.ML.PipelineInference/ExperimentsGenerator.cs @@ -110,7 +110,7 @@ public static List GenerateCandidates(IHostEnvironment env, string dataFi //get all the trainers for this task, and generate the initial set of candidates. // Exclude the hidden learners, and the metalinear learners. - var trainers = ComponentCatalog.GetAllDerivedClasses(typeof(ITrainer), predictorType).Where(cls => !cls.IsHidden); + var trainers = env.ComponentCatalog.GetAllDerivedClasses(typeof(ITrainer), predictorType).Where(cls => !cls.IsHidden); if (!string.IsNullOrEmpty(loaderSettings)) { diff --git a/src/Microsoft.ML.PipelineInference/RecipeInference.cs b/src/Microsoft.ML.PipelineInference/RecipeInference.cs index 67f3a7f02d..4363c8a4d7 100644 --- a/src/Microsoft.ML.PipelineInference/RecipeInference.cs +++ b/src/Microsoft.ML.PipelineInference/RecipeInference.cs @@ -128,13 +128,13 @@ public InferenceResult(SuggestedRecipe[] suggestedRecipes) } } - private static IEnumerable GetRecipes() + private static IEnumerable GetRecipes(IHostEnvironment env) { yield return new DefaultRecipe(); - yield return new BalancedTextClassificationRecipe(); - yield return new AccuracyFocusedRecipe(); - yield return new ExplorationComboRecipe(); - yield return new TreeLeafRecipe(); + yield return new BalancedTextClassificationRecipe(env); + yield return new AccuracyFocusedRecipe(env); + yield return new ExplorationComboRecipe(env); + yield return new TreeLeafRecipe(env); } public abstract class Recipe @@ -210,6 +210,14 @@ protected override IEnumerable ApplyCore(Type predictorType, public abstract class MultiClassRecipies : Recipe { + protected IHostEnvironment Host { get; } + + protected MultiClassRecipies(IHostEnvironment host) + { + Contracts.CheckValue(host, nameof(host)); + Host = host; + } + public override List AllowedTransforms() => base.AllowedTransforms().Where( expert => expert != typeof(TransformInference.Experts.Text) && @@ -227,6 +235,11 @@ public override List AllowedTransforms() => base.AllowedTransforms().Where public sealed class BalancedTextClassificationRecipe : MultiClassRecipies { + public BalancedTextClassificationRecipe(IHostEnvironment host) + : base(host) + { + } + public override List QualifierTransforms() => new List { typeof(TransformInference.Experts.TextBiGramTriGram) }; @@ -236,12 +249,12 @@ protected override IEnumerable ApplyCore(Type predictorType, SuggestedRecipe.SuggestedLearner learner = new SuggestedRecipe.SuggestedLearner(); if (predictorType == typeof(SignatureMultiClassClassifierTrainer)) { - learner.LoadableClassInfo = ComponentCatalog.GetLoadableClassInfo("OVA"); + learner.LoadableClassInfo = Host.ComponentCatalog.GetLoadableClassInfo("OVA"); learner.Settings = "p=AveragedPerceptron{iter=10}"; } else { - learner.LoadableClassInfo = ComponentCatalog.GetLoadableClassInfo(Learners.AveragedPerceptronTrainer.LoadNameValue); + learner.LoadableClassInfo = Host.ComponentCatalog.GetLoadableClassInfo(Learners.AveragedPerceptronTrainer.LoadNameValue); learner.Settings = "iter=10"; var epInput = new Legacy.Trainers.AveragedPerceptronBinaryClassifier { @@ -259,6 +272,11 @@ protected override IEnumerable ApplyCore(Type predictorType, public sealed class AccuracyFocusedRecipe : MultiClassRecipies { + public AccuracyFocusedRecipe(IHostEnvironment host) + : base(host) + { + } + public override List QualifierTransforms() => new List { typeof(TransformInference.Experts.TextUniGramTriGram) }; @@ -268,13 +286,13 @@ protected override IEnumerable ApplyCore(Type predictorType, SuggestedRecipe.SuggestedLearner learner = new SuggestedRecipe.SuggestedLearner(); if (predictorType == typeof(SignatureMultiClassClassifierTrainer)) { - learner.LoadableClassInfo = ComponentCatalog.GetLoadableClassInfo("OVA"); + learner.LoadableClassInfo = Host.ComponentCatalog.GetLoadableClassInfo("OVA"); learner.Settings = "p=FastTreeBinaryClassification"; } else { learner.LoadableClassInfo = - ComponentCatalog.GetLoadableClassInfo(FastTreeBinaryClassificationTrainer.LoadNameValue); + Host.ComponentCatalog.GetLoadableClassInfo(FastTreeBinaryClassificationTrainer.LoadNameValue); learner.Settings = ""; var epInput = new Legacy.Trainers.FastTreeBinaryClassifier(); learner.PipelineNode = new TrainerPipelineNode(epInput); @@ -288,6 +306,11 @@ protected override IEnumerable ApplyCore(Type predictorType, public sealed class ExplorationComboRecipe : MultiClassRecipies { + public ExplorationComboRecipe(IHostEnvironment host) + : base(host) + { + } + public override List QualifierTransforms() => new List { typeof(TransformInference.Experts.SdcaTransform) }; @@ -298,12 +321,12 @@ protected override IEnumerable ApplyCore(Type predictorType, if (predictorType == typeof(SignatureMultiClassClassifierTrainer)) { learner.LoadableClassInfo = - ComponentCatalog.GetLoadableClassInfo(Learners.SdcaMultiClassTrainer.LoadNameValue); + Host.ComponentCatalog.GetLoadableClassInfo(Learners.SdcaMultiClassTrainer.LoadNameValue); } else { learner.LoadableClassInfo = - ComponentCatalog.GetLoadableClassInfo(Learners.LinearClassificationTrainer.LoadNameValue); + Host.ComponentCatalog.GetLoadableClassInfo(Learners.LinearClassificationTrainer.LoadNameValue); var epInput = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier(); learner.PipelineNode = new TrainerPipelineNode(epInput); } @@ -317,6 +340,11 @@ protected override IEnumerable ApplyCore(Type predictorType, public sealed class TreeLeafRecipe : MultiClassRecipies { + public TreeLeafRecipe(IHostEnvironment host) + : base(host) + { + } + public override List QualifierTransforms() => new List { typeof(TransformInference.Experts.NaiveBayesTransform) }; @@ -325,7 +353,7 @@ protected override IEnumerable ApplyCore(Type predictorType, { SuggestedRecipe.SuggestedLearner learner = new SuggestedRecipe.SuggestedLearner(); learner.LoadableClassInfo = - ComponentCatalog.GetLoadableClassInfo(Learners.MultiClassNaiveBayesTrainer.LoadName); + Host.ComponentCatalog.GetLoadableClassInfo(Learners.MultiClassNaiveBayesTrainer.LoadName); learner.Settings = ""; var epInput = new Legacy.Trainers.NaiveBayesClassifier(); learner.PipelineNode = new TrainerPipelineNode(epInput); @@ -403,7 +431,7 @@ public static SuggestedRecipe[] InferRecipesFromData(IHostEnvironment env, strin AllowQuoting = splitResult.AllowQuote }; - settingsString = CommandLine.CmdParser.GetSettings(ch, finalLoaderArgs, new TextLoader.Arguments()); + settingsString = CommandLine.CmdParser.GetSettings(h, finalLoaderArgs, new TextLoader.Arguments()); ch.Info($"Loader options: {settingsString}"); ch.Info("Inferring recipes"); @@ -440,7 +468,7 @@ public static InferenceResult InferRecipes(IHostEnvironment env, TransformInfere using (var ch = h.Start("InferRecipes")) { var list = new List(); - foreach (var recipe in GetRecipes()) + foreach (var recipe in GetRecipes(h)) list.AddRange(recipe.Apply(transformInferenceResult, predictorType, ch)); if (list.Count == 0) diff --git a/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj b/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj index b420da0eb0..e5610126df 100644 --- a/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj +++ b/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj @@ -1,4 +1,4 @@ - + netstandard2.0 @@ -7,6 +7,10 @@ true + + + + diff --git a/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs b/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs index eb42e452bd..b142cea0ba 100644 --- a/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs +++ b/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs @@ -12,7 +12,6 @@ using Microsoft.ML.Runtime.Command; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Tools; @@ -152,9 +151,9 @@ public void GetDefaultSettingValues(IHostEnvironment env, string predictorName, /// private Dictionary GetDefaultSettings(IHostEnvironment env, string predictorName, string[] extraAssemblies = null) { - ComponentCatalog.CacheClassesExtra(extraAssemblies); + AssemblyLoadingUtils.LoadAndRegister(env, extraAssemblies); - var cls = ComponentCatalog.GetLoadableClassInfo(predictorName); + var cls = env.ComponentCatalog.GetLoadableClassInfo(predictorName); if (cls == null) { Console.Error.WriteLine("Can't load trainer '{0}'", predictorName); @@ -521,7 +520,7 @@ private static bool ValidateMamlOutput(string filename, string[] rawLines, out L ICommandLineComponentFactory commandLineTrainer = trainer as ICommandLineComponentFactory; Contracts.AssertValue(commandLineTrainer, "ResultProcessor can only work with ICommandLineComponentFactory."); - trainerClass = ComponentCatalog.GetLoadableClassInfo(commandLineTrainer.Name); + trainerClass = env.ComponentCatalog.GetLoadableClassInfo(commandLineTrainer.Name); trainerArgs = trainerClass.CreateArguments(); Dictionary predictorSettings; if (trainerArgs == null) @@ -678,7 +677,12 @@ public static bool ParseCommandArguments(IHostEnvironment env, string commandlin return false; } - commandClass = ComponentCatalog.GetLoadableClassInfo(kind); + commandClass = env.ComponentCatalog.GetLoadableClassInfo(kind); + if (commandClass == null) + { + commandArgs = null; + return false; + } commandArgs = commandClass.CreateArguments(); CmdParser.ParseArguments(env, settings, commandArgs); return true; @@ -1147,10 +1151,15 @@ private static object Load(Stream stream) } public static int Main(string[] args) + { + return Main(new ConsoleEnvironment(42), args); + } + + public static int Main(IHostEnvironment env, string[] args) { try { - Run(args); + Run(env, args); return 0; } catch (Exception e) @@ -1170,10 +1179,9 @@ public static int Main(string[] args) } } - protected static void Run(string[] args) + protected static void Run(IHostEnvironment env, string[] args) { ResultProcessorArguments cmd = new ResultProcessorArguments(); - ConsoleEnvironment env = new ConsoleEnvironment(42); List predictorResultsList = new List(); PredictionUtil.ParseArguments(env, cmd, PredictionUtil.CombineSettings(args)); @@ -1185,8 +1193,8 @@ protected static void Run(string[] args) if (cmd.IncludePerFoldResults) cmd.PerFoldResultSeparator = "" + PredictionUtil.SepCharFromString(cmd.PerFoldResultSeparator); - foreach (var dll in cmd.ExtraAssemblies) - ComponentCatalog.LoadAssembly(dll); + + AssemblyLoadingUtils.LoadAndRegister(env, cmd.ExtraAssemblies); if (cmd.Metrics.Length == 0) cmd.Metrics = null; diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs index f5da9327c1..8e95682b51 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs @@ -40,7 +40,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(FieldAwareFactorizationMachinePredictor).Assembly.FullName); } internal FieldAwareFactorizationMachinePredictor(IHostEnvironment env, bool norm, int fieldCount, int featureCount, int latentDim, @@ -353,7 +354,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(FieldAwareFactorizationMachinePredictionTransformer).Assembly.FullName); } private static FieldAwareFactorizationMachinePredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index eb0644d2da..ed41a14174 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -409,7 +409,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00020002, // Added model statistics verReadableCur: 0x00020001, verWeCanReadBack: 0x00020001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(LinearBinaryPredictor).Assembly.FullName); } /// @@ -597,7 +598,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00020001, // Fixed sparse serialization verReadableCur: 0x00020001, verWeCanReadBack: 0x00020001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(LinearRegressionPredictor).Assembly.FullName); } /// @@ -679,7 +681,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00020001, // Fixed sparse serialization verReadableCur: 0x00020001, verWeCanReadBack: 0x00020001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(PoissonRegressionPredictor).Assembly.FullName); } internal PoissonRegressionPredictor(IHostEnvironment env, ref VBuffer weights, Float bias) diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index b8986cce77..2a90e6cb2a 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -304,7 +304,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010003, // Added model stats verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(MulticlassLogisticRegressionPredictor).Assembly.FullName); } private const string ModelStatsSubModelFilename = "ModelStats"; diff --git a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs index e1594041f8..cc28c1caf1 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs @@ -59,7 +59,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(LinearModelStatistics).Assembly.FullName); } private readonly IHostEnvironment _env; diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs index 8c96ee1e0b..0817db47a5 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs @@ -151,7 +151,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(MultiClassNaiveBayesPredictor).Assembly.FullName); } private readonly int[] _labelHistogram; diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs index b918c5acce..6229e9f70f 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs @@ -214,7 +214,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(OvaPredictor).Assembly.FullName); } private const string SubPredictorFmt = "SubPredictor_{0:000}"; diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs index 6a8736bad1..da9bfc99b6 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs @@ -225,7 +225,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(PkpdPredictor).Assembly.FullName); } private const string SubPredictorFmt = "SubPredictor_{0:000}"; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs index e823d1fcd1..804943ac0d 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs @@ -120,7 +120,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(RandomPredictor).Assembly.FullName); } // Keep all the serializable state here. @@ -358,7 +359,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(PriorPredictor).Assembly.FullName); } private readonly float _prob; diff --git a/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs b/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs index ec80ce89b1..fded15ddaf 100644 --- a/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs +++ b/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs @@ -42,7 +42,7 @@ private string FindMetric(string userMetric, out bool maximizing) { StringBuilder sb = new StringBuilder(); - var evaluators = ComponentCatalog.GetAllDerivedClasses(typeof(IMamlEvaluator), typeof(SignatureMamlEvaluator)); + var evaluators = _host.ComponentCatalog.GetAllDerivedClasses(typeof(IMamlEvaluator), typeof(SignatureMamlEvaluator)); foreach (var evalInfo in evaluators) { var args = evalInfo.CreateArguments(); diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index 69532de8fc..0afb03b7a2 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -76,7 +76,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Upgraded when change for multiple outputs was implemented. verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(TensorFlowTransform).Assembly.FullName); } /// diff --git a/src/Microsoft.ML.Transforms/BootstrapSampleTransform.cs b/src/Microsoft.ML.Transforms/BootstrapSampleTransform.cs index 91106bd445..dc4bde762c 100644 --- a/src/Microsoft.ML.Transforms/BootstrapSampleTransform.cs +++ b/src/Microsoft.ML.Transforms/BootstrapSampleTransform.cs @@ -59,7 +59,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(BootstrapSampleTransform).Assembly.FullName); } internal const string RegistrationName = "BootstrapSample"; diff --git a/src/Microsoft.ML.Transforms/CompositeTransform.cs b/src/Microsoft.ML.Transforms/CompositeTransform.cs index 4dfeb3b312..48cdfe345f 100644 --- a/src/Microsoft.ML.Transforms/CompositeTransform.cs +++ b/src/Microsoft.ML.Transforms/CompositeTransform.cs @@ -30,7 +30,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(CompositeTransform).Assembly.FullName); } public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) diff --git a/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs b/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs index fbb544da50..ae9a721bd9 100644 --- a/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs +++ b/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs @@ -55,7 +55,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(GaussianFourierSampler).Assembly.FullName); } public const string LoadName = "GaussianRandom"; @@ -130,7 +131,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(LaplacianFourierSampler).Assembly.FullName); } public const string LoaderSignature = "RandLaplacianFourierExec"; diff --git a/src/Microsoft.ML.Transforms/GcnTransform.cs b/src/Microsoft.ML.Transforms/GcnTransform.cs index 67a6bbd951..d8d1618a41 100644 --- a/src/Microsoft.ML.Transforms/GcnTransform.cs +++ b/src/Microsoft.ML.Transforms/GcnTransform.cs @@ -236,7 +236,8 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderSignatureAlt: LoaderSignatureOld); + loaderSignatureAlt: LoaderSignatureOld, + loaderAssemblyName: typeof(LpNormNormalizerTransform).Assembly.FullName); } private const string RegistrationName = "LpNormNormalizer"; diff --git a/src/Microsoft.ML.Transforms/GroupTransform.cs b/src/Microsoft.ML.Transforms/GroupTransform.cs index c3a0bf8736..a40a9ea827 100644 --- a/src/Microsoft.ML.Transforms/GroupTransform.cs +++ b/src/Microsoft.ML.Transforms/GroupTransform.cs @@ -65,7 +65,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(GroupTransform).Assembly.FullName); } // REVIEW: maybe we want to have an option to keep all non-group scalar columns, as opposed to diff --git a/src/Microsoft.ML.Transforms/HashJoinTransform.cs b/src/Microsoft.ML.Transforms/HashJoinTransform.cs index d8e58c84ed..bd3d007fa8 100644 --- a/src/Microsoft.ML.Transforms/HashJoinTransform.cs +++ b/src/Microsoft.ML.Transforms/HashJoinTransform.cs @@ -168,7 +168,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010005, // Hash fix verReadableCur: 0x00010005, verWeCanReadBack: 0x00010005, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(HashJoinTransform).Assembly.FullName); } private readonly ColumnInfoEx[] _exes; diff --git a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs index f1503c336a..4f947415af 100644 --- a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs +++ b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs @@ -59,7 +59,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00000001, // Initial verReadableCur: 0x00000001, verWeCanReadBack: 0x00000001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(KeyToBinaryVectorTransform).Assembly.FullName); } private const string RegistrationName = "KeyToBinary"; diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs index 0a8d247766..41412e6e26 100644 --- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs +++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs @@ -59,7 +59,8 @@ private static VersionInfo GetVersionInfo() loaderSignature: LoaderSignature, // This is an older name and can be removed once we don't care about old code // being able to load this. - loaderSignatureAlt: "MissingFeatureFunction"); + loaderSignatureAlt: "MissingFeatureFunction", + loaderAssemblyName: typeof(MissingValueIndicatorTransform).Assembly.FullName); } private const string RegistrationName = "MissingIndicator"; diff --git a/src/Microsoft.ML.Transforms/NADropTransform.cs b/src/Microsoft.ML.Transforms/NADropTransform.cs index ea47fb9d1a..6d67fa5efe 100644 --- a/src/Microsoft.ML.Transforms/NADropTransform.cs +++ b/src/Microsoft.ML.Transforms/NADropTransform.cs @@ -59,7 +59,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(NADropTransform).Assembly.FullName); } private const string RegistrationName = "DropNAs"; diff --git a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs index 39b3d650d2..044751f459 100644 --- a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs +++ b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs @@ -59,7 +59,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(NAIndicatorTransform).Assembly.FullName); } internal const string Summary = "Create a boolean output column with the same number of slots as the input column, where the output value" diff --git a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs index 1f3b4ee220..16ebdcae6a 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs @@ -138,7 +138,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x0010002, // Added imputation methods. verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, - loaderSignature: LoadName); + loaderSignature: LoadName, + loaderAssemblyName: typeof(NAReplaceTransform).Assembly.FullName); } internal const string Summary = "Create an output column of the same type and size of the input column, where missing values " diff --git a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs index 5117496194..b78f008215 100644 --- a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs +++ b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs @@ -226,7 +226,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Save the input schema, for metadata verReadableCur: 0x00010002, verWeCanReadBack: 0x00010002, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(OptionalColumnTransform).Assembly.FullName); } private readonly Bindings _bindings; diff --git a/src/Microsoft.ML.Transforms/ProduceIdTransform.cs b/src/Microsoft.ML.Transforms/ProduceIdTransform.cs index 0934fd0086..b87476cefa 100644 --- a/src/Microsoft.ML.Transforms/ProduceIdTransform.cs +++ b/src/Microsoft.ML.Transforms/ProduceIdTransform.cs @@ -84,7 +84,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(ProduceIdTransform).Assembly.FullName); } private readonly Bindings _bindings; diff --git a/src/Microsoft.ML.Transforms/RffTransform.cs b/src/Microsoft.ML.Transforms/RffTransform.cs index 7cb51a6f65..866683bb99 100644 --- a/src/Microsoft.ML.Transforms/RffTransform.cs +++ b/src/Microsoft.ML.Transforms/RffTransform.cs @@ -220,7 +220,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(RffTransform).Assembly.FullName); } // These are parallel to Infos. diff --git a/src/Microsoft.ML.Transforms/TermLookupTransform.cs b/src/Microsoft.ML.Transforms/TermLookupTransform.cs index 98dc7a0933..9cadc31d9f 100644 --- a/src/Microsoft.ML.Transforms/TermLookupTransform.cs +++ b/src/Microsoft.ML.Transforms/TermLookupTransform.cs @@ -277,7 +277,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Dropped sizeof(Float). verReadableCur: 0x00010002, verWeCanReadBack: 0x00010002, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(TermLookupTransform).Assembly.FullName); } // This is the byte array containing the binary .idv file contents for the lookup data. diff --git a/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs b/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs index 47c6762d1b..da10ce2b65 100644 --- a/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs @@ -75,7 +75,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Updated to use UnitSeparator character instead of using for vector inputs. verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(CharTokenizeTransform).Assembly.FullName); } // Controls whether to mark the beginning/end of each row/slot with TextStartMarker/TextEndMarker. diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 380399d96a..f15546d2cc 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -291,7 +291,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(LdaTransform).Assembly.FullName); } private readonly ColInfoEx[] _exes; diff --git a/src/Microsoft.ML.Transforms/Text/NgramHashTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramHashTransform.cs index 8ab64f9da4..72eb1d650a 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramHashTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramHashTransform.cs @@ -317,7 +317,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Invert hash key values, hash fix verReadableCur: 0x00010002, verWeCanReadBack: 0x00010002, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(NgramHashTransform).Assembly.FullName); } private readonly Bindings _bindings; diff --git a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs index a696f2b092..3f24f290f9 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs @@ -200,7 +200,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010002, // Add support for TF-IDF verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(NgramTransform).Assembly.FullName); } private readonly VectorType[] _types; diff --git a/src/Microsoft.ML.Transforms/Text/StopWordsRemoverTransform.cs b/src/Microsoft.ML.Transforms/Text/StopWordsRemoverTransform.cs index d278326b33..42fac1e2bb 100644 --- a/src/Microsoft.ML.Transforms/Text/StopWordsRemoverTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/StopWordsRemoverTransform.cs @@ -235,7 +235,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(StopWordsRemoverTransform).Assembly.FullName); } private readonly bool?[] _resourcesExist; @@ -602,10 +603,11 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(CustomStopWordsRemoverTransform).Assembly.FullName); } - public const string StopwrodsManagerLoaderSignature = "CustomStopWordsManager"; + public const string StopwordsManagerLoaderSignature = "CustomStopWordsManager"; private static VersionInfo GetStopwrodsManagerVersionInfo() { return new VersionInfo( @@ -613,7 +615,8 @@ private static VersionInfo GetStopwrodsManagerVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: StopwrodsManagerLoaderSignature); + loaderSignature: StopwordsManagerLoaderSignature, + loaderAssemblyName: typeof(CustomStopWordsRemoverTransform).Assembly.FullName); } private static readonly ColumnType _outputType = new VectorType(TextType.Instance); diff --git a/src/Microsoft.ML.Transforms/Text/TextNormalizerTransform.cs b/src/Microsoft.ML.Transforms/Text/TextNormalizerTransform.cs index c96b1511d1..87a33fe2bf 100644 --- a/src/Microsoft.ML.Transforms/Text/TextNormalizerTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/TextNormalizerTransform.cs @@ -86,7 +86,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(TextNormalizerTransform).Assembly.FullName); } private const string RegistrationName = "TextNormalizer"; diff --git a/src/Microsoft.ML.Transforms/Text/TextTransform.cs b/src/Microsoft.ML.Transforms/Text/TextTransform.cs index cc801dd311..77e7739ad7 100644 --- a/src/Microsoft.ML.Transforms/Text/TextTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/TextTransform.cs @@ -639,7 +639,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(Transformer).Assembly.FullName); } } diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs index 5a5b382d75..b9ef52a414 100644 --- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs @@ -79,7 +79,8 @@ public static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, //Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(WordEmbeddingsTransform).Assembly.FullName); } private readonly PretrainedModelKind? _modelKind; diff --git a/src/Microsoft.ML.Transforms/Text/WordTokenizeTransform.cs b/src/Microsoft.ML.Transforms/Text/WordTokenizeTransform.cs index 35bfbd925c..391cfa0327 100644 --- a/src/Microsoft.ML.Transforms/Text/WordTokenizeTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordTokenizeTransform.cs @@ -141,7 +141,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(DelimitedTokenizeTransform).Assembly.FullName); } public override bool CanSavePfa => true; diff --git a/src/Microsoft.ML.Transforms/UngroupTransform.cs b/src/Microsoft.ML.Transforms/UngroupTransform.cs index d77cc48bbd..cc49ffe8f3 100644 --- a/src/Microsoft.ML.Transforms/UngroupTransform.cs +++ b/src/Microsoft.ML.Transforms/UngroupTransform.cs @@ -58,7 +58,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(UngroupTransform).Assembly.FullName); } /// diff --git a/src/Microsoft.ML.Transforms/WhiteningTransform.cs b/src/Microsoft.ML.Transforms/WhiteningTransform.cs index 5d1b3188b7..21c0870633 100644 --- a/src/Microsoft.ML.Transforms/WhiteningTransform.cs +++ b/src/Microsoft.ML.Transforms/WhiteningTransform.cs @@ -211,7 +211,8 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderSignatureAlt: LoaderSignatureOld); + loaderSignatureAlt: LoaderSignatureOld, + loaderAssemblyName: typeof(WhiteningTransform).Assembly.FullName); } private readonly ColInfoEx[] _exes; diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index b3d7068041..11407a3ec4 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.TestFramework; +using Microsoft.ML.Transforms; using System; using System.Collections.Generic; using System.Linq; diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs index 53b77440dc..3798246fa8 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs @@ -13,8 +13,10 @@ public sealed class TestEarlyStoppingCriteria { private IEarlyStoppingCriterion CreateEarlyStoppingCriterion(string name, string args, bool lowerIsBetter) { + var env = new ConsoleEnvironment() + .AddStandardComponents(); var sub = new SubComponent(name, args); - return sub.CreateInstance(new ConsoleEnvironment(), lowerIsBetter); + return sub.CreateInstance(env, lowerIsBetter); } [Fact] diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index d674d99c61..d94192f12e 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -15,11 +15,17 @@ using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.EntryPoints.JsonUtils; using Microsoft.ML.Runtime.FastTree; +using Microsoft.ML.Runtime.ImageAnalytics; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.LightGBM; +using Microsoft.ML.Runtime.Model.Onnx; using Microsoft.ML.Runtime.PCA; +using Microsoft.ML.Runtime.PipelineInference; +using Microsoft.ML.Runtime.SymSgd; using Microsoft.ML.Runtime.TextAnalytics; +using Microsoft.ML.Transforms; using Newtonsoft.Json; using Newtonsoft.Json.Linq; using Xunit; @@ -237,37 +243,19 @@ private string GetBuildPrefix() [Fact(Skip = "Execute this test if you want to regenerate ep-list and _manifest.json")] public void RegenerateEntryPointCatalog() { + var (epListContents, jObj) = BuildManifests(); + var buildPrefix = GetBuildPrefix(); var epListFile = buildPrefix + "_ep-list.tsv"; - var manifestFile = buildPrefix + "_manifest.json"; var entryPointsSubDir = Path.Combine("..", "Common", "EntryPoints"); var catalog = ModuleCatalog.CreateInstance(Env); var epListPath = GetBaselinePath(entryPointsSubDir, epListFile); DeleteOutputPath(epListPath); - var regex = new Regex(@"\r\n?|\n", RegexOptions.Compiled); - File.WriteAllLines(epListPath, catalog.AllEntryPoints() - .Select(x => string.Join("\t", - x.Name, - regex.Replace(x.Description, ""), - x.Method.DeclaringType, - x.Method.Name, - x.InputType, - x.OutputType) - .Replace(Environment.NewLine, "")) - .OrderBy(x => x)); - + File.WriteAllLines(epListPath, epListContents); - var jObj = JsonManifestUtils.BuildAllManifests(Env, catalog); - - //clean up the description from the new line characters - if (jObj[FieldNames.TopEntryPoints] != null && jObj[FieldNames.TopEntryPoints] is JArray) - { - foreach (JToken entry in jObj[FieldNames.TopEntryPoints].Children()) - if (entry[FieldNames.Desc] != null) - entry[FieldNames.Desc] = regex.Replace(entry[FieldNames.Desc].ToString(), ""); - } + var manifestFile = buildPrefix + "_manifest.json"; var manifestPath = GetBaselinePath(entryPointsSubDir, manifestFile); DeleteOutputPath(manifestPath); @@ -280,20 +268,49 @@ public void RegenerateEntryPointCatalog() } } - [Fact] public void EntryPointCatalog() { + var (epListContents, jObj) = BuildManifests(); + var buildPrefix = GetBuildPrefix(); var epListFile = buildPrefix + "_ep-list.tsv"; - var manifestFile = buildPrefix + "_manifest.json"; var entryPointsSubDir = Path.Combine("..", "Common", "EntryPoints"); var catalog = ModuleCatalog.CreateInstance(Env); var path = DeleteOutputPath(entryPointsSubDir, epListFile); + File.WriteAllLines(path, epListContents); + + CheckEquality(entryPointsSubDir, epListFile); + + var manifestFile = buildPrefix + "_manifest.json"; + var jPath = DeleteOutputPath(entryPointsSubDir, manifestFile); + using (var file = File.OpenWrite(jPath)) + using (var writer = new StreamWriter(file)) + using (var jw = new JsonTextWriter(writer)) + { + jw.Formatting = Formatting.Indented; + jObj.WriteTo(jw); + } + + CheckEquality(entryPointsSubDir, manifestFile); + Done(); + } + + private (IEnumerable epListContents, JObject manifest) BuildManifests() + { + Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryPredictor).Assembly); + Env.ComponentCatalog.RegisterAssembly(typeof(TensorFlowTransform).Assembly); + Env.ComponentCatalog.RegisterAssembly(typeof(ImageLoaderTransform).Assembly); + Env.ComponentCatalog.RegisterAssembly(typeof(SymSgdClassificationTrainer).Assembly); + Env.ComponentCatalog.RegisterAssembly(typeof(AutoInference).Assembly); + Env.ComponentCatalog.RegisterAssembly(typeof(SaveOnnxCommand).Assembly); + + var catalog = ModuleCatalog.CreateInstance(Env); + var regex = new Regex(@"\r\n?|\n", RegexOptions.Compiled); - File.WriteAllLines(path, catalog.AllEntryPoints() + var epListContents = catalog.AllEntryPoints() .Select(x => string.Join("\t", x.Name, regex.Replace(x.Description, ""), @@ -302,39 +319,27 @@ public void EntryPointCatalog() x.InputType, x.OutputType) .Replace(Environment.NewLine, "")) - .OrderBy(x => x)); + .OrderBy(x => x); - CheckEquality(entryPointsSubDir, epListFile); - - var jObj = JsonManifestUtils.BuildAllManifests(Env, catalog); + var manifest = JsonManifestUtils.BuildAllManifests(Env, catalog); //clean up the description from the new line characters - if (jObj[FieldNames.TopEntryPoints] != null && jObj[FieldNames.TopEntryPoints] is JArray) + if (manifest[FieldNames.TopEntryPoints] != null && manifest[FieldNames.TopEntryPoints] is JArray) { - foreach (JToken entry in jObj[FieldNames.TopEntryPoints].Children()) + foreach (JToken entry in manifest[FieldNames.TopEntryPoints].Children()) if (entry[FieldNames.Desc] != null) entry[FieldNames.Desc] = regex.Replace(entry[FieldNames.Desc].ToString(), ""); } - var jPath = DeleteOutputPath(entryPointsSubDir, manifestFile); - using (var file = File.OpenWrite(jPath)) - using (var writer = new StreamWriter(file)) - using (var jw = new JsonTextWriter(writer)) - { - jw.Formatting = Formatting.Indented; - jObj.WriteTo(jw); - } - - CheckEquality(entryPointsSubDir, manifestFile); - Done(); + return (epListContents, manifest); } [Fact] public void EntryPointInputBuilderOptionals() { - var catelog = ModuleCatalog.CreateInstance(Env); + var catalog = ModuleCatalog.CreateInstance(Env); - InputBuilder ib1 = new InputBuilder(Env, typeof(LogisticRegression.Arguments), catelog); + InputBuilder ib1 = new InputBuilder(Env, typeof(LogisticRegression.Arguments), catalog); // Ensure that InputBuilder unwraps the Optional correctly. var weightType = ib1.GetFieldTypeOrNull("WeightColumn"); Assert.True(weightType.Equals(typeof(string))); @@ -1794,12 +1799,14 @@ public void EntryPointEvaluateRanking() [Fact] public void EntryPointLightGbmBinary() { + Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryPredictor).Assembly); TestEntryPointRoutine("breast-cancer.txt", "Trainers.LightGbmBinaryClassifier"); } [Fact] public void EntryPointLightGbmMultiClass() { + Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryPredictor).Assembly); TestEntryPointRoutine(GetDataPath(@"iris.txt"), "Trainers.LightGbmClassifier"); } @@ -3728,6 +3735,8 @@ public void EntryPointWordEmbeddings() [Fact] public void EntryPointTensorFlowTransform() { + Env.ComponentCatalog.RegisterAssembly(typeof(TensorFlowTransform).Assembly); + TestEntryPointPipelineRoutine(GetDataPath("Train-Tiny-28x28.txt"), "col=Label:R4:0 col=Placeholder:R4:1-784", new[] { "Transforms.TensorFlowScorer" }, new[] diff --git a/test/Microsoft.ML.InferenceTesting/GenerateSweepCandidatesCommand.cs b/test/Microsoft.ML.InferenceTesting/GenerateSweepCandidatesCommand.cs index b964d4dae6..534ababbc7 100644 --- a/test/Microsoft.ML.InferenceTesting/GenerateSweepCandidatesCommand.cs +++ b/test/Microsoft.ML.InferenceTesting/GenerateSweepCandidatesCommand.cs @@ -85,7 +85,7 @@ public GenerateSweepCandidatesCommand(IHostEnvironment env, Arguments args) if (!string.IsNullOrWhiteSpace(args.Sweeper)) { - var info = ComponentCatalog.GetLoadableClassInfo(args.Sweeper); + var info = env.ComponentCatalog.GetLoadableClassInfo(args.Sweeper); _host.CheckUserArg(info?.SignatureTypes[0] == typeof(SignatureSweeper), nameof(args.Sweeper), "Please specify a valid sweeper."); _sweeper = args.Sweeper; @@ -95,7 +95,7 @@ public GenerateSweepCandidatesCommand(IHostEnvironment env, Arguments args) if (!string.IsNullOrWhiteSpace(args.Mode)) { - var info = ComponentCatalog.GetLoadableClassInfo(args.Mode); + var info = env.ComponentCatalog.GetLoadableClassInfo(args.Mode); _host.CheckUserArg(info?.Type == typeof(TrainCommand) || info?.Type == typeof(TrainTestCommand) || info?.Type == typeof(CrossValidationCommand), nameof(args.Mode), "Invalid mode."); diff --git a/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs b/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs index 2369136f95..6d95829454 100644 --- a/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs +++ b/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs @@ -113,9 +113,9 @@ private class SimpleArg /// ToString is overrided by CmdParser.GetSettings which is of primary for this test /// /// - public string ToString(IExceptionContext ectx) + public string ToString(IHostEnvironment env) { - return CmdParser.GetSettings(ectx, this, new SimpleArg(), SettingsFlags.None); + return CmdParser.GetSettings(env, this, new SimpleArg(), SettingsFlags.None); } public override bool Equals(object obj) diff --git a/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs b/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs index 1b991300ea..9d0af99420 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.EntryPoints.JsonUtils; using Microsoft.ML.Runtime.PipelineInference; +using Microsoft.ML.TestFramework; using Newtonsoft.Json.Linq; using System.Collections.Generic; using System.Linq; @@ -26,7 +27,8 @@ public TestAutoInference(ITestOutputHelper helper) [TestCategory("EntryPoints")] public void TestLearn() { - using (var env = new ConsoleEnvironment()) + using (var env = new LocalEnvironment() + .AddStandardComponents()) // AutoInference.InferPipelines uses ComponentCatalog to read text data { string pathData = GetDataPath("adult.train"); string pathDataTest = GetDataPath("adult.test"); @@ -72,7 +74,8 @@ public void TestLearn() [Fact(Skip = "Need CoreTLC specific baseline update")] public void TestTextDatasetLearn() { - using (var env = new ConsoleEnvironment()) + using (var env = new LocalEnvironment() + .AddStandardComponents()) // AutoInference uses ComponentCatalog to find all learners { string pathData = GetDataPath(@"../UnitTest/tweets_labeled_10k_test_validation.tsv"); int batchSize = 5; @@ -96,7 +99,8 @@ public void TestTextDatasetLearn() [Fact] public void TestPipelineNodeCloning() { - using (var env = new ConsoleEnvironment()) + using (var env = new LocalEnvironment() + .AddStandardComponents()) // RecipeInference.AllowedLearners uses ComponentCatalog to find all learners { var lr1 = RecipeInference .AllowedLearners(env, MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer) @@ -138,15 +142,16 @@ public void TestHyperparameterFreezing() int batchSize = 1; int numIterations = 10; int numTransformLevels = 3; - using (var env = new ConsoleEnvironment()) + using (var env = new LocalEnvironment() + .AddStandardComponents()) // AutoInference uses ComponentCatalog to find all learners { SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc); // Using the simple, uniform random sampling (with replacement) brain - PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(Env); + PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(env); // Run initial experiments - var amls = AutoInference.InferPipelines(Env, autoMlBrain, pathData, "", out var _, numTransformLevels, batchSize, + var amls = AutoInference.InferPipelines(env, autoMlBrain, pathData, "", out var _, numTransformLevels, batchSize, metric, out var bestPipeline, numOfSampleRows, new IterationTerminator(numIterations), MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer); @@ -186,15 +191,16 @@ public void TestRegressionPipelineWithMinimizingMetric() int batchSize = 5; int numIterations = 10; int numTransformLevels = 1; - using (var env = new ConsoleEnvironment()) + using (var env = new LocalEnvironment() + .AddStandardComponents()) // AutoInference uses ComponentCatalog to find all learners { SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.AccuracyMicro); // Using the simple, uniform random sampling (with replacement) brain - PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(Env); + PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(env); // Run initial experiments - var amls = AutoInference.InferPipelines(Env, autoMlBrain, pathData, "", out var _, numTransformLevels, batchSize, + var amls = AutoInference.InferPipelines(env, autoMlBrain, pathData, "", out var _, numTransformLevels, batchSize, metric, out var bestPipeline, numOfSampleRows, new IterationTerminator(numIterations), MacroUtils.TrainerKinds.SignatureRegressorTrainer); @@ -220,15 +226,16 @@ public void TestLearnerConstrainingByName() int numIterations = 1; int numTransformLevels = 2; var retainedLearnerNames = new[] { $"LogisticRegressionBinaryClassifier", $"FastTreeBinaryClassifier" }; - using (var env = new ConsoleEnvironment()) + using (var env = new LocalEnvironment() + .AddStandardComponents()) // AutoInference uses ComponentCatalog to find all learners { SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc); // Using the simple, uniform random sampling (with replacement) brain. - PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(Env); + PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(env); // Run initial experiment. - var amls = AutoInference.InferPipelines(Env, autoMlBrain, pathData, "", out var _, + var amls = AutoInference.InferPipelines(env, autoMlBrain, pathData, "", out var _, numTransformLevels, batchSize, metric, out var _, numOfSampleRows, new IterationTerminator(numIterations), MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer); diff --git a/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs b/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs index 5f65ca4c28..cd86e83414 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs @@ -22,6 +22,12 @@ public TestPipelineSweeper(ITestOutputHelper helper) { } + protected override void InitializeCore() + { + base.InitializeCore(); + Env.ComponentCatalog.RegisterAssembly(typeof(AutoInference).Assembly); + } + [Fact] public void PipelineSweeperBasic() { diff --git a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs index 4de917921b..fdf2f50168 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs @@ -7,7 +7,6 @@ using System; using System.Collections.Generic; using System.IO; -using Microsoft.ML.Runtime.CommandLine; namespace Microsoft.ML.Runtime.RunTests { @@ -15,6 +14,8 @@ namespace Microsoft.ML.Runtime.RunTests using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.FastTree; using Microsoft.ML.Runtime.FastTree.Internal; + using Microsoft.ML.Runtime.LightGBM; + using Microsoft.ML.Runtime.SymSgd; using Microsoft.ML.TestFramework; using System.Linq; using System.Runtime.InteropServices; @@ -27,6 +28,20 @@ namespace Microsoft.ML.Runtime.RunTests /// public sealed partial class TestPredictors : BaseTestPredictors { + protected override void InitializeCore() + { + base.InitializeCore(); + InitializeEnvironment(Env); + } + + protected override void InitializeEnvironment(IHostEnvironment environment) + { + base.InitializeEnvironment(environment); + + environment.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryPredictor).Assembly); + environment.ComponentCatalog.RegisterAssembly(typeof(SymSgdClassificationTrainer).Assembly); + } + /// /// Get a list of datasets for binary classifier base test. /// diff --git a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs index f3c21feabf..e4ce5fac50 100644 --- a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs +++ b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs @@ -108,7 +108,8 @@ public void Init() string logPath = Path.Combine(logDir, FullTestName + LogSuffix); LogWriter = OpenWriter(logPath); _passed = true; - Env = new ConsoleEnvironment(42, outWriter: LogWriter, errWriter: LogWriter); + Env = new ConsoleEnvironment(42, outWriter: LogWriter, errWriter: LogWriter) + .AddStandardComponents(); InitializeCore(); } diff --git a/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs b/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs index ddd9dfb8da..1b9eed9fe5 100644 --- a/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs +++ b/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs @@ -295,7 +295,7 @@ protected void RunResultProcessorTest(string[] dataFiles, string outPath, string if (extraArgs != null) args.AddRange(extraArgs); - ResultProcessor.Main(args.ToArray()); + ResultProcessor.Main(Env, args.ToArray()); } private static string GetNamePrefix(string testType, PredictorAndArgs predictor, TestDataset dataset, string extraTag = "") diff --git a/test/Microsoft.ML.TestFramework/DataPipe/Parquet.cs b/test/Microsoft.ML.TestFramework/DataPipe/Parquet.cs index 26f265bf9e..b3b68c5a8a 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/Parquet.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/Parquet.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Runtime.Data; using System; using Xunit; @@ -9,6 +10,11 @@ namespace Microsoft.ML.Runtime.RunTests { public sealed partial class TestParquet : TestDataPipeBase { + protected override void InitializeCore() + { + base.InitializeCore(); + Env.ComponentCatalog.RegisterAssembly(typeof(ParquetLoader).Assembly); + } [Fact] public void TestParquetPrimitiveDataTypes() diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index 0ec0e35341..e1b4986f06 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -311,7 +311,7 @@ protected IDataLoader TestCore(string pathData, bool keepHidden, string[] argsPi protected IDataLoader CreatePipeDataLoader(IHostEnvironment env, string pathData, string[] argsPipe, out MultiFileSource files) { - VerifyArgParsing(argsPipe); + VerifyArgParsing(env, argsPipe); // Default to breast-cancer.txt. if (string.IsNullOrEmpty(pathData)) @@ -350,7 +350,7 @@ protected void TestApplyTransformsToData(IHostEnvironment env, IDataLoader pipe, Failed(); } - protected void VerifyArgParsing(string[] strs) + protected void VerifyArgParsing(IHostEnvironment env, string[] strs) { string str = CmdParser.CombineSettings(strs); var args = new CompositeDataLoader.Arguments(); @@ -361,18 +361,18 @@ protected void VerifyArgParsing(string[] strs) } // For the loader and each transform, verify that custom unparsing is correct. - VerifyCustArgs(args.Loader); + VerifyCustArgs(env, args.Loader); foreach (var kvp in args.Transform) - VerifyCustArgs(kvp.Value); + VerifyCustArgs(env, kvp.Value); } - protected void VerifyCustArgs(IComponentFactory factory) + protected void VerifyCustArgs(IHostEnvironment env, IComponentFactory factory) where TRes : class { if (factory is ICommandLineComponentFactory commandLineFactory) { var str = commandLineFactory.GetSettingsString(); - var info = ComponentCatalog.GetLoadableClassInfo(commandLineFactory.Name, commandLineFactory.SignatureType); + var info = env.ComponentCatalog.GetLoadableClassInfo(commandLineFactory.Name, commandLineFactory.SignatureType); Assert.NotNull(info); var def = info.CreateArguments(); diff --git a/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs b/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs new file mode 100644 index 0000000000..986e82a9c4 --- /dev/null +++ b/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Ensemble; +using Microsoft.ML.Runtime.FastTree; +using Microsoft.ML.Runtime.KMeans; +using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.PCA; + +namespace Microsoft.ML.TestFramework +{ + public static class EnvironmentExtensions + { + public static TEnvironment AddStandardComponents(this TEnvironment env) + where TEnvironment : IHostEnvironment + { + env.ComponentCatalog.RegisterAssembly(typeof(TextLoader).Assembly); // ML.Data + env.ComponentCatalog.RegisterAssembly(typeof(LinearPredictor).Assembly); // ML.StandardLearners + env.ComponentCatalog.RegisterAssembly(typeof(CategoricalTransform).Assembly); // ML.Transforms + env.ComponentCatalog.RegisterAssembly(typeof(FastTreeBinaryPredictor).Assembly); // ML.FastTree + env.ComponentCatalog.RegisterAssembly(typeof(EnsemblePredictor).Assembly); // ML.Ensemble + env.ComponentCatalog.RegisterAssembly(typeof(KMeansPredictor).Assembly); // ML.KMeansClustering + env.ComponentCatalog.RegisterAssembly(typeof(PcaPredictor).Assembly); // ML.PCA + env.ComponentCatalog.RegisterAssembly(typeof(Experiment).Assembly); // ML.Legacy + return env; + } + } +} diff --git a/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj b/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj index ad50df8042..b191c08a91 100644 --- a/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj +++ b/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj @@ -7,11 +7,15 @@ + + + + diff --git a/test/Microsoft.ML.TestFramework/TestCommandBase.cs b/test/Microsoft.ML.TestFramework/TestCommandBase.cs index a0c72a15ab..10b0e9552c 100644 --- a/test/Microsoft.ML.TestFramework/TestCommandBase.cs +++ b/test/Microsoft.ML.TestFramework/TestCommandBase.cs @@ -14,6 +14,7 @@ using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Tools; +using Microsoft.ML.TestFramework; using Xunit; using Xunit.Abstractions; @@ -270,6 +271,11 @@ public virtual OutputPath MetricsPath() } } + protected virtual void InitializeEnvironment(IHostEnvironment environment) + { + environment.AddStandardComponents(); + } + /// /// Runs a command with some arguments. Note that the input /// objects are used for comparison only. @@ -283,6 +289,8 @@ protected bool TestCore(RunContextBase ctx, string cmdName, string args, params using (var newWriter = OpenWriter(outputPath.Path)) using (var env = new ConsoleEnvironment(42, outWriter: newWriter, errWriter: newWriter)) { + InitializeEnvironment(env); + int res; res = MainForTest(env, newWriter, string.Format("{0} {1}", cmdName, args), ctx.BaselineProgress); if (res != 0) diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index adc229e328..c5aabae99d 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -29,8 +29,21 @@ public void TestEstimatorChain() { var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); - var invalidData = env.CreateLoader("Text{col=ImagePath:R4:0}", new MultiFileSource(dataFile)); + var data = TextLoader.Create(env, new TextLoader.Arguments() + { + Column = new[] + { + new TextLoader.Column("ImagePath", DataKind.TX, 0), + new TextLoader.Column("Name", DataKind.TX, 1), + } + }, new MultiFileSource(dataFile)); + var invalidData = TextLoader.Create(env, new TextLoader.Arguments() + { + Column = new[] + { + new TextLoader.Column("ImagePath", DataKind.R4, 0), + } + }, new MultiFileSource(dataFile)); var pipe = new ImageLoaderEstimator(env, imageFolder, ("ImagePath", "ImageReal")) .Append(new ImageResizerEstimator(env, "ImageReal", "ImageReal", 100, 100)) @@ -49,7 +62,14 @@ public void TestEstimatorSaveLoad() { var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + var data = TextLoader.Create(env, new TextLoader.Arguments() + { + Column = new[] + { + new TextLoader.Column("ImagePath", DataKind.TX, 0), + new TextLoader.Column("Name", DataKind.TX, 1), + } + }, new MultiFileSource(dataFile)); var pipe = new ImageLoaderEstimator(env, imageFolder, ("ImagePath", "ImageReal")) .Append(new ImageResizerEstimator(env, "ImageReal", "ImageReal", 100, 100)) @@ -82,7 +102,14 @@ public void TestSaveImages() { var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + var data = TextLoader.Create(env, new TextLoader.Arguments() + { + Column = new[] + { + new TextLoader.Column("ImagePath", DataKind.TX, 0), + new TextLoader.Column("Name", DataKind.TX, 1), + } + }, new MultiFileSource(dataFile)); var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] @@ -129,7 +156,14 @@ public void TestGreyscaleTransformImages() var imageWidth = 100; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + var data = TextLoader.Create(env, new TextLoader.Arguments() + { + Column = new[] + { + new TextLoader.Column("ImagePath", DataKind.TX, 0), + new TextLoader.Column("Name", DataKind.TX, 1), + } + }, new MultiFileSource(dataFile)); var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] @@ -192,7 +226,14 @@ public void TestBackAndForthConversionWithAlphaInterleave() var imageWidth = 130; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + var data = TextLoader.Create(env, new TextLoader.Arguments() + { + Column = new[] + { + new TextLoader.Column("ImagePath", DataKind.TX, 0), + new TextLoader.Column("Name", DataKind.TX, 1), + } + }, new MultiFileSource(dataFile)); var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] @@ -275,7 +316,14 @@ public void TestBackAndForthConversionWithoutAlphaInterleave() var imageWidth = 130; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + var data = TextLoader.Create(env, new TextLoader.Arguments() + { + Column = new[] + { + new TextLoader.Column("ImagePath", DataKind.TX, 0), + new TextLoader.Column("Name", DataKind.TX, 1), + } + }, new MultiFileSource(dataFile)); var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] @@ -358,7 +406,14 @@ public void TestBackAndForthConversionWithAlphaNoInterleave() var imageWidth = 130; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + var data = TextLoader.Create(env, new TextLoader.Arguments() + { + Column = new[] + { + new TextLoader.Column("ImagePath", DataKind.TX, 0), + new TextLoader.Column("Name", DataKind.TX, 1), + } + }, new MultiFileSource(dataFile)); var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] @@ -441,7 +496,14 @@ public void TestBackAndForthConversionWithoutAlphaNoInterleave() var imageWidth = 130; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + var data = TextLoader.Create(env, new TextLoader.Arguments() + { + Column = new[] + { + new TextLoader.Column("ImagePath", DataKind.TX, 0), + new TextLoader.Column("Name", DataKind.TX, 1), + } + }, new MultiFileSource(dataFile)); var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] @@ -524,7 +586,14 @@ public void TestBackAndForthConversionWithAlphaInterleaveNoOffset() var imageWidth = 130; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + var data = TextLoader.Create(env, new TextLoader.Arguments() + { + Column = new[] + { + new TextLoader.Column("ImagePath", DataKind.TX, 0), + new TextLoader.Column("Name", DataKind.TX, 1), + } + }, new MultiFileSource(dataFile)); var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] @@ -603,7 +672,14 @@ public void TestBackAndForthConversionWithoutAlphaInterleaveNoOffset() var imageWidth = 130; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + var data = TextLoader.Create(env, new TextLoader.Arguments() + { + Column = new[] + { + new TextLoader.Column("ImagePath", DataKind.TX, 0), + new TextLoader.Column("Name", DataKind.TX, 1), + } + }, new MultiFileSource(dataFile)); var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] @@ -682,7 +758,14 @@ public void TestBackAndForthConversionWithAlphaNoInterleaveNoOffset() var imageWidth = 130; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + var data = TextLoader.Create(env, new TextLoader.Arguments() + { + Column = new[] + { + new TextLoader.Column("ImagePath", DataKind.TX, 0), + new TextLoader.Column("Name", DataKind.TX, 1), + } + }, new MultiFileSource(dataFile)); var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] @@ -761,7 +844,14 @@ public void TestBackAndForthConversionWithoutAlphaNoInterleaveNoOffset() var imageWidth = 130; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + var data = TextLoader.Create(env, new TextLoader.Arguments() + { + Column = new[] + { + new TextLoader.Column("ImagePath", DataKind.TX, 0), + new TextLoader.Column("Name", DataKind.TX, 1), + } + }, new MultiFileSource(dataFile)); var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs index 0d20c937c3..68857999e8 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.TestFramework; using System.Linq; using Xunit; @@ -25,7 +26,8 @@ public partial class ApiScenariosTests [Fact] void DecomposableTrainAndPredict() { - using (var env = new LocalEnvironment()) + using (var env = new LocalEnvironment() + .AddStandardComponents()) // ScoreUtils.GetScorer requires scorers to be registered in the ComponentCatalog { var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.irisData.trainFilename))); var term = TermTransform.Create(env, loader, "Label"); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs index 6e937ce238..68d7994842 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.TestFramework; using System.Linq; using Xunit; @@ -25,7 +26,8 @@ public partial class ApiScenariosTests void New_DecomposableTrainAndPredict() { var dataPath = GetDataPath(TestDatasets.irisData.trainFilename); - using (var env = new LocalEnvironment()) + using (var env = new LocalEnvironment() + .AddStandardComponents()) // ScoreUtils.GetScorer requires scorers to be registered in the ComponentCatalog { var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); var term = TermTransform.Create(env, loader, "Label"); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs index c5a6a40703..aa425e3f4e 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs @@ -52,7 +52,8 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature); + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(LoaderWrapper).Assembly.FullName); } public LoaderWrapper(IHostEnvironment env, ModelLoadContext ctx) diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs index f77dac19ba..15e0824c7d 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs @@ -2,6 +2,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.TestFramework; using System; using System.Linq; using Xunit; @@ -19,7 +20,8 @@ public partial class ApiScenariosTests [Fact] void Extensibility() { - using (var env = new LocalEnvironment()) + using (var env = new LocalEnvironment() + .AddStandardComponents()) // ScoreUtils.GetScorer requires scorers to be registered in the ComponentCatalog { var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.irisData.trainFilename))); Action action = (i, j) => diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 71f5f95f33..609ccf9c1c 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -420,7 +420,14 @@ public void TensorFlowTransformCifar() var imageWidth = 32; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + var data = TextLoader.Create(env, new TextLoader.Arguments() + { + Column = new[] + { + new TextLoader.Column("ImagePath", DataKind.TX, 0), + new TextLoader.Column("Name", DataKind.TX, 1), + } + }, new MultiFileSource(dataFile)); var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] @@ -474,8 +481,14 @@ public void TensorFlowTransformCifarInvalidShape() var imageWidth = 28; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); - var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); - + var data = TextLoader.Create(env, new TextLoader.Arguments() + { + Column = new[] + { + new TextLoader.Column("ImagePath", DataKind.TX, 0), + new TextLoader.Column("Name", DataKind.TX, 1), + } + }, new MultiFileSource(dataFile)); var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() { Column = new ImageLoaderTransform.Column[1] diff --git a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs index d102ea758c..6ea74d8dbb 100644 --- a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs @@ -134,7 +134,7 @@ void TestCommandLine() { using (var env = new ConsoleEnvironment()) { - Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=a:R4:0-3 col=b:R4:0-3} xf=TFTransform{inputs=a inputs=b outputs=c model={model_matmul/frozen_saved_model.pb}} in=f:\2.txt" }), (int)0); + Assert.Equal(0, Maml.Main(new[] { @"showschema loader=Text{col=a:R4:0-3 col=b:R4:0-3} xf=TFTransform{inputs=a inputs=b outputs=c model={model_matmul/frozen_saved_model.pb}} in=f:\2.txt" })); } } diff --git a/test/Microsoft.ML.Tests/TermEstimatorTests.cs b/test/Microsoft.ML.Tests/TermEstimatorTests.cs index c69c91d23c..93c827ff30 100644 --- a/test/Microsoft.ML.Tests/TermEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/TermEstimatorTests.cs @@ -154,7 +154,7 @@ void TestCommandLine() { using (var env = new ConsoleEnvironment()) { - Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} in=f:\2.txt" }), (int)0); + Assert.Equal(0, Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} in=f:\2.txt" })); } } diff --git a/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs index b995f7c984..e4746ccffd 100644 --- a/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs @@ -298,7 +298,7 @@ private void ValidateBagMetadata(IDataView result) [Fact] public void TestCommandLine() { - Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Cat{col=B:A} in=f:\2.txt" }), (int)0); + Assert.Equal(0, Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Cat{col=B:A} in=f:\2.txt" })); } [Fact] diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs index 56036bbc02..c5a9620c2b 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs @@ -151,7 +151,7 @@ private void ValidateMetadata(IDataView result) [Fact] public void TestCommandLine() { - Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToBinary{col=C:B} in=f:\2.txt" }), (int)0); + Assert.Equal(0, Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToBinary{col=C:B} in=f:\2.txt" })); } [Fact] diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs index faf8b47e79..a460d9482b 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs @@ -216,7 +216,7 @@ private void ValidateMetadata(IDataView result) [Fact] public void TestCommandLine() { - Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToVector{col=C:B col={name=D source=B bag+}} in=f:\2.txt" }), (int)0); + Assert.Equal(0, Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToVector{col=C:B col={name=D source=B bag+}} in=f:\2.txt" })); } [Fact]