diff --git a/src/Common/AssemblyLoadingUtils.cs b/src/Common/AssemblyLoadingUtils.cs
new file mode 100644
index 0000000000..64477aa7e0
--- /dev/null
+++ b/src/Common/AssemblyLoadingUtils.cs
@@ -0,0 +1,290 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Internal.Utilities;
+using System;
+using System.IO;
+using System.IO.Compression;
+using System.Reflection;
+
+namespace Microsoft.ML.Runtime
+{
+ internal static class AssemblyLoadingUtils
+ {
+ ///
+ /// Make sure the given assemblies are loaded and that their loadable classes have been catalogued.
+ ///
+ public static void LoadAndRegister(IHostEnvironment env, string[] assemblies)
+ {
+ Contracts.AssertValue(env);
+
+ if (Utils.Size(assemblies) > 0)
+ {
+ foreach (string path in assemblies)
+ {
+ Exception ex = null;
+ try
+ {
+ // REVIEW: Will LoadFrom ever return null?
+ Contracts.CheckNonEmpty(path, nameof(path));
+ var assem = LoadAssembly(env, path);
+ if (assem != null)
+ continue;
+ }
+ catch (Exception e)
+ {
+ ex = e;
+ }
+
+ // If it is a zip file, load it that way.
+ ZipArchive zip;
+ try
+ {
+ zip = ZipFile.OpenRead(path);
+ }
+ catch (Exception e)
+ {
+ // Couldn't load as an assembly and not a zip, so warn the user.
+ ex = ex ?? e;
+ Console.Error.WriteLine("Warning: Could not load '{0}': {1}", path, ex.Message);
+ continue;
+ }
+
+ string dir;
+ try
+ {
+ dir = CreateTempDirectory();
+ }
+ catch (Exception e)
+ {
+ throw Contracts.ExceptIO(e, "Creating temp directory for extra assembly zip extraction failed: '{0}'", path);
+ }
+
+ try
+ {
+ zip.ExtractToDirectory(dir);
+ }
+ catch (Exception e)
+ {
+ throw Contracts.ExceptIO(e, "Extracting extra assembly zip failed: '{0}'", path);
+ }
+
+ LoadAssembliesInDir(env, dir, false);
+ }
+ }
+ }
+
+ public static IDisposable CreateAssemblyRegistrar(IHostEnvironment env, string loadAssembliesPath = null)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValueOrNull(loadAssembliesPath);
+
+ return new AssemblyRegistrar(env, loadAssembliesPath);
+ }
+
+ public static void RegisterCurrentLoadedAssemblies(IHostEnvironment env)
+ {
+ Contracts.CheckValue(env, nameof(env));
+
+ foreach (Assembly a in AppDomain.CurrentDomain.GetAssemblies())
+ {
+ TryRegisterAssembly(env.ComponentCatalog, a);
+ }
+ }
+
+ private static string CreateTempDirectory()
+ {
+ string dir = GetTempPath();
+ Directory.CreateDirectory(dir);
+ return dir;
+ }
+
+ private static string GetTempPath()
+ {
+ Guid guid = Guid.NewGuid();
+ return Path.GetFullPath(Path.Combine(Path.GetTempPath(), "MLNET_" + guid.ToString()));
+ }
+
+ private static readonly string[] _filePrefixesToAvoid = new string[] {
+ "api-ms-win",
+ "clr",
+ "coreclr",
+ "dbgshim",
+ "ext-ms-win",
+ "microsoft.bond.",
+ "microsoft.cosmos.",
+ "microsoft.csharp",
+ "microsoft.data.",
+ "microsoft.hpc.",
+ "microsoft.live.",
+ "microsoft.platformbuilder.",
+ "microsoft.visualbasic",
+ "microsoft.visualstudio.",
+ "microsoft.win32",
+ "microsoft.windowsapicodepack.",
+ "microsoft.windowsazure.",
+ "mscor",
+ "msvc",
+ "petzold.",
+ "roslyn.",
+ "sho",
+ "sni",
+ "sqm",
+ "system.",
+ "zlib",
+ };
+
+ private static bool ShouldSkipPath(string path)
+ {
+ string name = Path.GetFileName(path).ToLowerInvariant();
+ switch (name)
+ {
+ case "cqo.dll":
+ case "fasttreenative.dll":
+ case "libiomp5md.dll":
+ case "libvw.dll":
+ case "matrixinterf.dll":
+ case "microsoft.ml.neuralnetworks.gpucuda.dll":
+ case "mklimports.dll":
+ case "microsoft.research.controls.decisiontrees.dll":
+ case "microsoft.ml.neuralnetworks.sse.dll":
+ case "neuraltreeevaluator.dll":
+ case "optimizationbuilderdotnet.dll":
+ case "parallelcommunicator.dll":
+ case "microsoft.ml.runtime.runtests.dll":
+ case "scopecompiler.dll":
+ case "tbb.dll":
+ case "internallearnscope.dll":
+ case "unmanagedlib.dll":
+ case "vcclient.dll":
+ case "libxgboost.dll":
+ case "zedgraph.dll":
+ case "__scopecodegen__.dll":
+ case "cosmosClientApi.dll":
+ return true;
+ }
+
+ foreach (var s in _filePrefixesToAvoid)
+ {
+ if (name.StartsWith(s, StringComparison.OrdinalIgnoreCase))
+ return true;
+ }
+
+ return false;
+ }
+
+ private static void LoadAssembliesInDir(IHostEnvironment env, string dir, bool filter)
+ {
+ if (!Directory.Exists(dir))
+ return;
+
+ using (var ch = env.Start("LoadAssembliesInDir"))
+ {
+ // Load all dlls in the given directory.
+ var paths = Directory.EnumerateFiles(dir, "*.dll");
+ foreach (string path in paths)
+ {
+ if (filter && ShouldSkipPath(path))
+ {
+ ch.Info($"Skipping assembly '{path}' because its name was filtered.");
+ continue;
+ }
+
+ LoadAssembly(env, path);
+ }
+ }
+ }
+
+ ///
+ /// Given an assembly path, load the assembly and register it with the ComponentCatalog.
+ ///
+ private static Assembly LoadAssembly(IHostEnvironment env, string path)
+ {
+ Assembly assembly = null;
+ try
+ {
+ assembly = Assembly.LoadFrom(path);
+ }
+ catch (Exception e)
+ {
+ using (var ch = env.Start("LoadAssembly"))
+ {
+ ch.Error("Could not load assembly {0}:\n{1}", path, e.ToString());
+ }
+ return null;
+ }
+
+ if (assembly != null)
+ {
+ TryRegisterAssembly(env.ComponentCatalog, assembly);
+ }
+
+ return assembly;
+ }
+
+ ///
+ /// Checks whether references the assembly containing LoadableClassAttributeBase,
+ /// and therefore can contain components.
+ ///
+ private static bool CanContainComponents(Assembly assembly)
+ {
+ var targetFullName = typeof(LoadableClassAttributeBase).Assembly.GetName().FullName;
+
+ bool found = false;
+ foreach (var name in assembly.GetReferencedAssemblies())
+ {
+ if (name.FullName == targetFullName)
+ {
+ found = true;
+ break;
+ }
+ }
+
+ return found;
+ }
+
+ private static void TryRegisterAssembly(ComponentCatalog catalog, Assembly assembly)
+ {
+ // Don't try to index dynamic generated assembly
+ if (assembly.IsDynamic)
+ return;
+
+ if (!CanContainComponents(assembly))
+ return;
+
+ catalog.RegisterAssembly(assembly);
+ }
+
+ private sealed class AssemblyRegistrar : IDisposable
+ {
+ private readonly IHostEnvironment _env;
+
+ public AssemblyRegistrar(IHostEnvironment env, string path)
+ {
+ _env = env;
+
+ RegisterCurrentLoadedAssemblies(_env);
+
+ if (!string.IsNullOrEmpty(path))
+ {
+ LoadAssembliesInDir(_env, path, true);
+ path = Path.Combine(path, "AutoLoad");
+ LoadAssembliesInDir(_env, path, true);
+ }
+
+ AppDomain.CurrentDomain.AssemblyLoad += CurrentDomainAssemblyLoad;
+ }
+
+ public void Dispose()
+ {
+ AppDomain.CurrentDomain.AssemblyLoad -= CurrentDomainAssemblyLoad;
+ }
+
+ private void CurrentDomainAssemblyLoad(object sender, AssemblyLoadEventArgs args)
+ {
+ TryRegisterAssembly(_env.ComponentCatalog, args.LoadedAssembly);
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Api/ComponentCreation.cs b/src/Microsoft.ML.Api/ComponentCreation.cs
index 1694a1cf4f..f34b25235c 100644
--- a/src/Microsoft.ML.Api/ComponentCreation.cs
+++ b/src/Microsoft.ML.Api/ComponentCreation.cs
@@ -439,7 +439,7 @@ private static TRes CreateCore(IHostEnvironment env, TArgs ar
{
env.CheckValue(args, nameof(args));
- var classes = ComponentCatalog.FindLoadableClasses();
+ var classes = env.ComponentCatalog.FindLoadableClasses();
if (classes.Length == 0)
throw env.Except("Couldn't find a {0} class that accepts {1} as arguments.", typeof(TRes).Name, typeof(TArgs).FullName);
if (classes.Length > 1)
diff --git a/src/Microsoft.ML.Api/SerializableLambdaTransform.cs b/src/Microsoft.ML.Api/SerializableLambdaTransform.cs
index 7de6e522d8..38bd58e76d 100644
--- a/src/Microsoft.ML.Api/SerializableLambdaTransform.cs
+++ b/src/Microsoft.ML.Api/SerializableLambdaTransform.cs
@@ -29,7 +29,8 @@ public static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(SerializableLambdaTransform).Assembly.FullName);
}
public const string LoaderSignature = "UserLambdaMapTransform";
diff --git a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
index d5a204dd45..a25f8cecf6 100644
--- a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
+++ b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
@@ -414,7 +414,7 @@ public static string CombineSettings(string[] settings)
}
// REVIEW: Add a method for cloning arguments, instead of going to text and back.
- public static string GetSettings(IExceptionContext ectx, object values, object defaults, SettingsFlags flags = SettingsFlags.Default)
+ public static string GetSettings(IHostEnvironment env, object values, object defaults, SettingsFlags flags = SettingsFlags.Default)
{
Type t1 = values.GetType();
Type t2 = defaults.GetType();
@@ -430,7 +430,7 @@ public static string GetSettings(IExceptionContext ectx, object values, object d
return null;
var info = GetArgumentInfo(t, defaults);
- return GetSettingsCore(ectx, info, values, flags);
+ return GetSettingsCore(env, info, values, flags);
}
public static IEnumerable> GetSettingPairs(IHostEnvironment env, object values, object defaults, SettingsFlags flags = SettingsFlags.None)
@@ -862,20 +862,20 @@ private bool Parse(ArgumentInfo info, string[] strs, object destination)
return !hadError;
}
- private static string GetSettingsCore(IExceptionContext ectx, ArgumentInfo info, object values, SettingsFlags flags)
+ private static string GetSettingsCore(IHostEnvironment env, ArgumentInfo info, object values, SettingsFlags flags)
{
StringBuilder sb = new StringBuilder();
if (info.ArgDef != null)
{
var val = info.ArgDef.GetValue(values);
- info.ArgDef.AppendSetting(ectx, sb, val, flags);
+ info.ArgDef.AppendSetting(env, sb, val, flags);
}
foreach (Argument arg in info.Args)
{
var val = arg.GetValue(values);
- arg.AppendSetting(ectx, sb, val, flags);
+ arg.AppendSetting(env, sb, val, flags);
}
return sb.ToString();
@@ -886,7 +886,7 @@ private static string GetSettingsCore(IExceptionContext ectx, ArgumentInfo info,
/// It deals with custom "unparse" functionality, as well as quoting. It also appends to a StringBuilder
/// instead of returning a string.
///
- private static void AppendCustomItem(IExceptionContext ectx, ArgumentInfo info, object values, SettingsFlags flags, StringBuilder sb)
+ private static void AppendCustomItem(IHostEnvironment env, ArgumentInfo info, object values, SettingsFlags flags, StringBuilder sb)
{
int ich = sb.Length;
// We always call unparse, even when NoUnparse is specified, since Unparse can "cleanse", which
@@ -902,13 +902,13 @@ private static void AppendCustomItem(IExceptionContext ectx, ArgumentInfo info,
if (info.ArgDef != null)
{
var val = info.ArgDef.GetValue(values);
- info.ArgDef.AppendSetting(ectx, sb, val, flags);
+ info.ArgDef.AppendSetting(env, sb, val, flags);
}
foreach (Argument arg in info.Args)
{
var val = arg.GetValue(values);
- arg.AppendSetting(ectx, sb, val, flags);
+ arg.AppendSetting(env, sb, val, flags);
}
string str = sb.ToString(ich, sb.Length - ich);
@@ -916,14 +916,14 @@ private static void AppendCustomItem(IExceptionContext ectx, ArgumentInfo info,
CmdQuoter.QuoteValue(str, sb, force: true);
}
- private IEnumerable> GetSettingPairsCore(IExceptionContext ectx, ArgumentInfo info, object values, SettingsFlags flags)
+ private IEnumerable> GetSettingPairsCore(IHostEnvironment env, ArgumentInfo info, object values, SettingsFlags flags)
{
StringBuilder buffer = new StringBuilder();
foreach (Argument arg in info.Args)
{
string key = arg.GetKey(flags);
object value = arg.GetValue(values);
- foreach (string val in arg.GetSettingStrings(ectx, value, buffer))
+ foreach (string val in arg.GetSettingStrings(env, value, buffer))
yield return new KeyValuePair(key, val);
}
}
@@ -943,13 +943,13 @@ public ArgumentHelpStrings(string syntax, string help)
///
/// A user friendly usage string describing the command line argument syntax.
///
- private string GetUsageString(IExceptionContext ectx, ArgumentInfo info, bool showRsp = true, int? columns = null)
+ private string GetUsageString(IHostEnvironment env, ArgumentInfo info, bool showRsp = true, int? columns = null)
{
int screenWidth = columns ?? Console.BufferWidth;
if (screenWidth <= 0)
screenWidth = 80;
- ArgumentHelpStrings[] strings = GetAllHelpStrings(ectx, info, showRsp);
+ ArgumentHelpStrings[] strings = GetAllHelpStrings(env, info, showRsp);
int maxParamLen = 0;
foreach (ArgumentHelpStrings helpString in strings)
@@ -1039,17 +1039,17 @@ private static void AddNewLine(string newLine, StringBuilder builder, ref int cu
currentColumn = 0;
}
- private static ArgumentHelpStrings[] GetAllHelpStrings(IExceptionContext ectx, ArgumentInfo info, bool showRsp)
+ private ArgumentHelpStrings[] GetAllHelpStrings(IHostEnvironment env, ArgumentInfo info, bool showRsp)
{
List strings = new List();
if (info.ArgDef != null)
- strings.Add(GetHelpStrings(ectx, info.ArgDef));
+ strings.Add(GetHelpStrings(env, info.ArgDef));
foreach (Argument arg in info.Args)
{
if (!arg.IsHidden)
- strings.Add(GetHelpStrings(ectx, arg));
+ strings.Add(GetHelpStrings(env, arg));
}
if (showRsp)
@@ -1058,9 +1058,9 @@ private static ArgumentHelpStrings[] GetAllHelpStrings(IExceptionContext ectx, A
return strings.ToArray();
}
- private static ArgumentHelpStrings GetHelpStrings(IExceptionContext ectx, Argument arg)
+ private ArgumentHelpStrings GetHelpStrings(IHostEnvironment env, Argument arg)
{
- return new ArgumentHelpStrings(arg.GetSyntaxHelp(), arg.GetFullHelpText(ectx));
+ return new ArgumentHelpStrings(arg.GetSyntaxHelp(), arg.GetFullHelpText(env, this));
}
private bool LexFileArguments(string fileName, out string[] arguments)
@@ -1994,7 +1994,7 @@ private bool ParseValue(CmdParser owner, string data, out object value)
return false;
}
- private void AppendHelpValue(IExceptionContext ectx, StringBuilder builder, object value)
+ private void AppendHelpValue(IHostEnvironment env, CmdParser owner, StringBuilder builder, object value)
{
if (value == null)
builder.Append("{}");
@@ -2006,19 +2006,19 @@ private void AppendHelpValue(IExceptionContext ectx, StringBuilder builder, obje
foreach (object o in (System.Array)value)
{
builder.Append(pre);
- AppendHelpValue(ectx, builder, o);
+ AppendHelpValue(env, owner, builder, o);
pre = ", ";
}
}
else if (value is IComponentFactory)
{
string name;
- var catalog = ModuleCatalog.CreateInstance(ectx);
+ var catalog = owner._catalog.Value;
var type = value.GetType();
bool success = catalog.TryGetComponentShortName(type, out name);
Contracts.Assert(success);
- var settings = GetSettings(ectx, value, Activator.CreateInstance(type));
+ var settings = GetSettings(env, value, Activator.CreateInstance(type));
builder.Append(name);
if (!string.IsNullOrWhiteSpace(settings))
{
@@ -2035,7 +2035,7 @@ private void AppendHelpValue(IExceptionContext ectx, StringBuilder builder, obje
}
// If value differs from the default, appends the setting to sb.
- public void AppendSetting(IExceptionContext ectx, StringBuilder sb, object value, SettingsFlags flags)
+ public void AppendSetting(IHostEnvironment env, StringBuilder sb, object value, SettingsFlags flags)
{
object def = DefaultValue;
if (!IsCollection)
@@ -2043,13 +2043,13 @@ public void AppendSetting(IExceptionContext ectx, StringBuilder sb, object value
if (value == null)
{
if (def != null || IsRequired)
- AppendSettingCore(ectx, sb, "", flags);
+ AppendSettingCore(env, sb, "", flags);
}
else if (def == null || !value.Equals(def))
{
var buffer = new StringBuilder();
- if (!(value is IComponentFactory) || (GetString(ectx, value, buffer) != GetString(ectx, def, buffer)))
- AppendSettingCore(ectx, sb, value, flags);
+ if (!(value is IComponentFactory) || (GetString(env, value, buffer) != GetString(env, def, buffer)))
+ AppendSettingCore(env, sb, value, flags);
}
return;
}
@@ -2076,10 +2076,10 @@ public void AppendSetting(IExceptionContext ectx, StringBuilder sb, object value
}
foreach (object x in vals)
- AppendSettingCore(ectx, sb, x, flags);
+ AppendSettingCore(env, sb, x, flags);
}
- private void AppendSettingCore(IExceptionContext ectx, StringBuilder sb, object value, SettingsFlags flags)
+ private void AppendSettingCore(IHostEnvironment env, StringBuilder sb, object value, SettingsFlags flags)
{
if (sb.Length > 0)
sb.Append(" ");
@@ -2100,11 +2100,11 @@ private void AppendSettingCore(IExceptionContext ectx, StringBuilder sb, object
else if (value is bool)
sb.Append((bool)value ? "+" : "-");
else if (IsCustomItemType)
- AppendCustomItem(ectx, _infoCustom, value, flags, sb);
+ AppendCustomItem(env, _infoCustom, value, flags, sb);
else if (IsComponentFactory)
{
var buffer = new StringBuilder();
- sb.Append(GetString(ectx, value, buffer));
+ sb.Append(GetString(env, value, buffer));
}
else
sb.Append(value.ToString());
@@ -2129,7 +2129,7 @@ private void ExtractTag(object value, out string tag, out object newValue)
// If value differs from the default, return the string representation of 'value',
// or an IEnumerable of string representations if 'value' is an array.
- public IEnumerable GetSettingStrings(IExceptionContext ectx, object value, StringBuilder buffer)
+ public IEnumerable GetSettingStrings(IHostEnvironment env, object value, StringBuilder buffer)
{
object def = DefaultValue;
@@ -2138,12 +2138,12 @@ public IEnumerable GetSettingStrings(IExceptionContext ectx, object valu
if (value == null)
{
if (def != null || IsRequired)
- yield return GetString(ectx, value, buffer);
+ yield return GetString(env, value, buffer);
}
else if (def == null || !value.Equals(def))
{
- if (!(value is IComponentFactory) || (GetString(ectx, value, buffer) != GetString(ectx, def, buffer)))
- yield return GetString(ectx, value, buffer);
+ if (!(value is IComponentFactory) || (GetString(env, value, buffer) != GetString(env, def, buffer)))
+ yield return GetString(env, value, buffer);
}
yield break;
}
@@ -2171,10 +2171,10 @@ public IEnumerable GetSettingStrings(IExceptionContext ectx, object valu
}
foreach (object x in vals)
- yield return GetString(ectx, x, buffer);
+ yield return GetString(env, x, buffer);
}
- private string GetString(IExceptionContext ectx, object value, StringBuilder buffer)
+ private string GetString(IHostEnvironment env, object value, StringBuilder buffer)
{
if (value == null)
return "{}";
@@ -2192,12 +2192,13 @@ private string GetString(IExceptionContext ectx, object value, StringBuilder buf
if (value is IComponentFactory)
{
string name;
- var catalog = ModuleCatalog.CreateInstance(ectx);
+ // TODO: ModuleCatalog should be on env
+ var catalog = ModuleCatalog.CreateInstance(env);
var type = value.GetType();
bool isModuleComponent = catalog.TryGetComponentShortName(type, out name);
if (isModuleComponent)
{
- var settings = GetSettings(ectx, value, Activator.CreateInstance(type));
+ var settings = GetSettings(env, value, Activator.CreateInstance(type));
buffer.Clear();
buffer.Append(name);
if (!string.IsNullOrWhiteSpace(settings))
@@ -2208,20 +2209,12 @@ private string GetString(IExceptionContext ectx, object value, StringBuilder buf
}
return buffer.ToString();
}
- else if (value is ICommandLineComponentFactory)
- {
- return value.ToString();
- }
- else
- {
- throw ectx.Except($"IComponentFactory instances either need to be EntryPointComponents or implement {nameof(ICommandLineComponentFactory)}.");
- }
}
return value.ToString();
}
- public string GetFullHelpText(IExceptionContext ectx)
+ public string GetFullHelpText(IHostEnvironment env, CmdParser owner)
{
if (IsHidden)
return null;
@@ -2248,7 +2241,7 @@ public string GetFullHelpText(IExceptionContext ectx)
if (builder.Length > 0)
builder.Append(" ");
builder.Append("Default value:'");
- AppendHelpValue(ectx, builder, DefaultValue);
+ AppendHelpValue(env, owner, builder, DefaultValue);
builder.Append('\'');
}
if (Utils.Size(ShortNames) != 0)
diff --git a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
index 78ec31205f..c09c7c8e70 100644
--- a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
+++ b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
@@ -2,15 +2,13 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Internal.Utilities;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
-using System.IO;
-using System.IO.Compression;
using System.Linq;
using System.Reflection;
-using Microsoft.ML.Runtime.Internal.Utilities;
-using Microsoft.ML.Runtime.CommandLine;
// REVIEW: Determine ideal namespace.
namespace Microsoft.ML.Runtime
@@ -22,8 +20,17 @@ namespace Microsoft.ML.Runtime
/// types for component instantiation. Each component may also specify an "arguments object" that should
/// be provided at instantiation time.
///
- public static class ComponentCatalog
+ public sealed class ComponentCatalog
{
+ internal ComponentCatalog()
+ {
+ _lock = new object();
+ _cachedAssemblies = new HashSet();
+ _classesByKey = new ConcurrentDictionary();
+ _classes = new ConcurrentQueue();
+ _signatures = new ConcurrentDictionary();
+ }
+
///
/// Provides information on an instantiatable component, aka, loadable class.
///
@@ -248,201 +255,18 @@ public object CreateArguments()
}
}
- ///
- /// Debug reporting level.
- ///
- public static int DebugLevel = 1;
-
- // Do not initialize this one - the initial null value is used as a "flag" to prime things.
- private static ConcurrentQueue _assemblyQueue;
-
- // The assemblies that are loaded by Reflection.LoadAssembly or Assembly.Load* after we started tracking
- // the load events. We will provide assembly resolving for these assemblies. This is created simultaneously
- // with s_assemblyQueue.
- private static ConcurrentDictionary _loadedAssemblies;
-
- // This lock protects s_cachedAssemblies and s_cachedPaths only. The collection of ClassInfos is concurrent
+ // This lock protects _cachedAssemblies only. The collection of ClassInfos is concurrent
// so needs no protection.
- private static object _lock = new object();
- private static HashSet _cachedAssemblies = new HashSet();
- private static HashSet _cachedPaths = new HashSet();
+ private readonly object _lock;
+ private readonly HashSet _cachedAssemblies;
// Map from key/name to loadable class. Note that the same ClassInfo may appear
// multiple times. For the set of unique infos, use s_classes.
- private static ConcurrentDictionary _classesByKey = new ConcurrentDictionary();
+ private readonly ConcurrentDictionary _classesByKey;
// The unique ClassInfos and Signatures.
- private static ConcurrentQueue _classes = new ConcurrentQueue();
- private static ConcurrentDictionary _signatures = new ConcurrentDictionary();
-
- public static string[] FilePrefixesToAvoid = new string[] {
- "api-ms-win",
- "clr",
- "coreclr",
- "dbgshim",
- "ext-ms-win",
- "microsoft.bond.",
- "microsoft.cosmos.",
- "microsoft.csharp",
- "microsoft.data.",
- "microsoft.hpc.",
- "microsoft.live.",
- "microsoft.platformbuilder.",
- "microsoft.visualbasic",
- "microsoft.visualstudio.",
- "microsoft.win32",
- "microsoft.windowsapicodepack.",
- "microsoft.windowsazure.",
- "mscor",
- "msvc",
- "petzold.",
- "roslyn.",
- "sho",
- "sni",
- "sqm",
- "system.",
- "zlib",
- };
-
- private static bool ShouldSkipPath(string path)
- {
- string name = Path.GetFileName(path).ToLowerInvariant();
- switch (name)
- {
- case "cqo.dll":
- case "fasttreenative.dll":
- case "libiomp5md.dll":
- case "libvw.dll":
- case "matrixinterf.dll":
- case "microsoft.ml.neuralnetworks.gpucuda.dll":
- case "mklimports.dll":
- case "microsoft.research.controls.decisiontrees.dll":
- case "microsoft.ml.neuralnetworks.sse.dll":
- case "neuraltreeevaluator.dll":
- case "optimizationbuilderdotnet.dll":
- case "parallelcommunicator.dll":
- case "microsoft.ml.runtime.runtests.dll":
- case "scopecompiler.dll":
- case "tbb.dll":
- case "internallearnscope.dll":
- case "unmanagedlib.dll":
- case "vcclient.dll":
- case "libxgboost.dll":
- case "zedgraph.dll":
- case "__scopecodegen__.dll":
- case "cosmosClientApi.dll":
- return true;
- }
-
- foreach (var s in FilePrefixesToAvoid)
- {
- if (name.StartsWith(s))
- return true;
- }
-
- return false;
- }
-
- ///
- /// This loads assemblies that are in our "root" directory (where this assembly is) and caches
- /// information for the loadable classes in loaded assemblies.
- ///
- private static void CacheLoadedAssemblies()
- {
- // The target assembly is the one containing LoadableClassAttributeBase. If an assembly doesn't reference
- // the target, then we don't want to scan its assembly attributes (there's no point in doing so).
- var target = typeof(LoadableClassAttributeBase).Assembly;
-
- lock (_lock)
- {
- if (_assemblyQueue == null)
- {
- // Create the loaded assembly queue and dictionary, set up the AssemblyLoad / AssemblyResolve
- // event handlers and populate the queue / dictionary with all assemblies that are currently loaded.
- Contracts.Assert(_assemblyQueue == null);
- Contracts.Assert(_loadedAssemblies == null);
-
- _assemblyQueue = new ConcurrentQueue();
- _loadedAssemblies = new ConcurrentDictionary();
-
- AppDomain.CurrentDomain.AssemblyLoad += CurrentDomainAssemblyLoad;
- AppDomain.CurrentDomain.AssemblyResolve += CurrentDomainAssemblyResolve;
-
- foreach (Assembly a in AppDomain.CurrentDomain.GetAssemblies())
- {
- // Ignore dynamic assemblies.
- if (a.IsDynamic)
- continue;
-
- _assemblyQueue.Enqueue(a);
- if (!_loadedAssemblies.TryAdd(a.FullName, a))
- {
- // Duplicate loading.
- Console.Error.WriteLine("Duplicate loaded assembly '{0}'", a.FullName);
- }
- }
-
- // Load all assemblies in our directory.
- var moduleName = typeof(ComponentCatalog).Module.FullyQualifiedName;
-
- // If were are loaded in the context of SQL CLR then the FullyQualifiedName and Name properties are set to
- // string "" and we skip scanning current directory.
- if (moduleName != "")
- {
- string dir = Path.GetDirectoryName(moduleName);
- LoadAssembliesInDir(dir, true);
- dir = Path.Combine(dir, "AutoLoad");
- LoadAssembliesInDir(dir, true);
- }
- }
-
- Contracts.AssertValue(_assemblyQueue);
- Contracts.AssertValue(_loadedAssemblies);
-
- Assembly assembly;
- while (_assemblyQueue.TryDequeue(out assembly))
- {
- if (!_cachedAssemblies.Add(assembly.FullName))
- continue;
-
- if (assembly != target)
- {
- bool found = false;
- var targetName = target.GetName();
- foreach (var name in assembly.GetReferencedAssemblies())
- {
- if (name.Name == targetName.Name)
- {
- found = true;
- break;
- }
- }
- if (!found)
- continue;
- }
-
- int added = 0;
- foreach (LoadableClassAttributeBase attr in assembly.GetCustomAttributes(typeof(LoadableClassAttributeBase)))
- {
- MethodInfo getter = null;
- ConstructorInfo ctor = null;
- MethodInfo create = null;
- bool requireEnvironment = false;
- if (attr.InstanceType != typeof(void) && !TryGetIniters(attr.InstanceType, attr.LoaderType, attr.CtorTypes, out getter, out ctor, out create, out requireEnvironment))
- {
- Console.Error.WriteLine(
- "CacheClassesFromAssembly: can't instantiate loadable class {0} with name {1}",
- attr.InstanceType.Name, attr.LoadNames[0]);
- Contracts.Assert(getter == null && ctor == null && create == null);
- }
- var info = new LoadableClassInfo(attr, getter, ctor, create, requireEnvironment);
-
- AddClass(info, attr.LoadNames);
- added++;
- }
- }
- }
- }
+ private readonly ConcurrentQueue _classes;
+ private readonly ConcurrentDictionary _signatures;
private static bool TryGetIniters(Type instType, Type loaderType, Type[] parmTypes,
out MethodInfo getter, out ConstructorInfo ctor, out MethodInfo create, out bool requireEnvironment)
@@ -472,7 +296,7 @@ private static bool TryGetIniters(Type instType, Type loaderType, Type[] parmTyp
return false;
}
- private static void AddClass(LoadableClassInfo info, string[] loadNames)
+ private void AddClass(LoadableClassInfo info, string[] loadNames)
{
_classes.Enqueue(info);
foreach (var sigType in info.SignatureTypes)
@@ -528,157 +352,53 @@ private static MethodInfo FindCreateMethod(Type instType, Type loaderType, Type[
return meth;
}
- private static void LoadAssembliesInDir(string dir, bool filter)
- {
- if (!Directory.Exists(dir))
- return;
-
- // Load all dlls in the given directory.
- var paths = Directory.EnumerateFiles(dir, "*.dll");
- foreach (string path in paths)
- {
- if (filter && ShouldSkipPath(path))
- continue;
- // Loading the assembly is enough because of our event handler.
- LoadAssembly(path);
- }
- }
-
- private static void CurrentDomainAssemblyLoad(object sender, AssemblyLoadEventArgs args)
- {
- // Don't try to index dynamic generated assembly
- if (args.LoadedAssembly.IsDynamic)
- return;
- _assemblyQueue.Enqueue(args.LoadedAssembly);
- if (!_loadedAssemblies.TryAdd(args.LoadedAssembly.FullName, args.LoadedAssembly))
- {
- // Duplicate loading.
- Console.Error.WriteLine("Duplicate loading of the assembly '{0}'", args.LoadedAssembly.FullName);
- }
- }
-
- private static Assembly CurrentDomainAssemblyResolve(object sender, ResolveEventArgs args)
- {
- // REVIEW: currently, the resolving happens on exact matches of the full name.
- // This has proved to work with the C# transform. We might need to change the resolving logic when the need arises.
- Assembly found;
- if (_loadedAssemblies.TryGetValue(args.Name, out found))
- return found;
- return null;
- }
-
///
- /// Given an assembly path, load the assembly.
+ /// Registers all the components in the specified assembly by looking for loadable classes
+ /// and adding them to the catalog.
///
- public static Assembly LoadAssembly(string path)
+ ///
+ /// The assembly to register.
+ ///
+ ///
+ /// true to throw an exception if there are errors with registering the components;
+ /// false to skip any errors.
+ ///
+ public void RegisterAssembly(Assembly assembly, bool throwOnError = true)
{
- try
- {
- return LoadFrom(path);
- }
- catch (Exception e)
- {
- if (DebugLevel > 2)
- Console.Error.WriteLine("Could not load assembly {0}:\n{1}", path, e.ToString());
- return null;
- }
- }
-
- private static Assembly LoadFrom(string path)
- {
- Contracts.CheckNonEmpty(path, nameof(path));
- return Assembly.LoadFrom(path);
- }
-
- ///
- /// Make sure the given assemblies are loaded and that their loadable classes have been catalogued.
- ///
- public static void CacheClassesExtra(string[] assemblies)
- {
- if (Utils.Size(assemblies) > 0)
+ lock (_lock)
{
- lock (_lock)
+ if (_cachedAssemblies.Add(assembly.FullName))
{
- foreach (string path in assemblies)
+ foreach (LoadableClassAttributeBase attr in assembly.GetCustomAttributes(typeof(LoadableClassAttributeBase)))
{
- if (!_cachedPaths.Add(path))
- continue;
-
- Exception ex = null;
- try
- {
- // REVIEW: Will LoadFrom ever return null?
- var assem = LoadFrom(path);
- if (assem != null)
- continue;
- }
- catch (Exception e)
- {
- ex = e;
- }
-
- // If it is a zip file, load it that way.
- ZipArchive zip;
- try
- {
- zip = ZipFile.OpenRead(path);
- }
- catch (Exception e)
- {
- // Couldn't load as an assembly and not a zip, so warn the user.
- ex = ex ?? e;
- Console.Error.WriteLine("Warning: Could not load '{0}': {1}", path, ex.Message);
- continue;
- }
-
- string dir;
- try
- {
- dir = CreateTempDirectory();
- }
- catch (Exception e)
- {
- throw Contracts.ExceptIO(e, "Creating temp directory for extra assembly zip extraction failed: '{0}'", path);
- }
-
- try
- {
- zip.ExtractToDirectory(dir);
- }
- catch (Exception e)
+ MethodInfo getter = null;
+ ConstructorInfo ctor = null;
+ MethodInfo create = null;
+ bool requireEnvironment = false;
+ if (attr.InstanceType != typeof(void) && !TryGetIniters(attr.InstanceType, attr.LoaderType, attr.CtorTypes, out getter, out ctor, out create, out requireEnvironment))
{
- throw Contracts.ExceptIO(e, "Extracting extra assembly zip failed: '{0}'", path);
+ if (throwOnError)
+ {
+ throw new InvalidOperationException(string.Format(
+ "Can't instantiate loadable class {0} with name {1}",
+ attr.InstanceType.Name, attr.LoadNames[0]));
+ }
+ Contracts.Assert(getter == null && ctor == null && create == null);
}
+ var info = new LoadableClassInfo(attr, getter, ctor, create, requireEnvironment);
- LoadAssembliesInDir(dir, false);
+ AddClass(info, attr.LoadNames);
}
}
}
-
- CacheLoadedAssemblies();
- }
-
- private static string CreateTempDirectory()
- {
- string dir = GetTempPath();
- Directory.CreateDirectory(dir);
- return dir;
- }
-
- private static string GetTempPath()
- {
- Guid guid = Guid.NewGuid();
- return Path.GetFullPath(Path.Combine(Path.GetTempPath(), "TLC_" + guid.ToString()));
}
///
/// Return an array containing information for all instantiatable components.
/// If provided, the given set of assemblies is loaded first.
///
- public static LoadableClassInfo[] GetAllClasses(string[] assemblies = null)
+ public LoadableClassInfo[] GetAllClasses()
{
- CacheClassesExtra(assemblies);
-
return _classes.ToArray();
}
@@ -686,13 +406,11 @@ public static LoadableClassInfo[] GetAllClasses(string[] assemblies = null)
/// Return an array containing information for instantiatable components with the given
/// signature and base type. If provided, the given set of assemblies is loaded first.
///
- public static LoadableClassInfo[] GetAllDerivedClasses(Type typeBase, Type typeSig, string[] assemblies = null)
+ public LoadableClassInfo[] GetAllDerivedClasses(Type typeBase, Type typeSig)
{
Contracts.CheckValue(typeBase, nameof(typeBase));
Contracts.CheckValueOrNull(typeSig);
- CacheClassesExtra(assemblies);
-
// Apply the default.
if (typeSig == null)
typeSig = typeof(SignatureDefault);
@@ -706,10 +424,8 @@ public static LoadableClassInfo[] GetAllDerivedClasses(Type typeBase, Type typeS
/// Return an array containing all the known signature types. If provided, the given set of assemblies
/// is loaded first.
///
- public static Type[] GetAllSignatureTypes(string[] assemblies = null)
+ public Type[] GetAllSignatureTypes()
{
- CacheClassesExtra(assemblies);
-
return _signatures.Select(kvp => kvp.Key).ToArray();
}
@@ -726,58 +442,47 @@ public static string SignatureToString(Type sig)
return kind;
}
- private static LoadableClassInfo FindClassCore(LoadableClassInfo.Key key)
+ private LoadableClassInfo FindClassCore(LoadableClassInfo.Key key)
{
LoadableClassInfo info;
- if (_classesByKey.TryGetValue(key, out info))
- return info;
-
- CacheLoadedAssemblies();
-
if (_classesByKey.TryGetValue(key, out info))
return info;
return null;
}
- public static LoadableClassInfo[] FindLoadableClasses(string name)
+ public LoadableClassInfo[] FindLoadableClasses(string name)
{
name = name.ToLowerInvariant().Trim();
- CacheLoadedAssemblies();
-
var res = _classes
.Where(ci => ci.LoadNames.Select(n => n.ToLowerInvariant().Trim()).Contains(name))
.ToArray();
return res;
}
- public static LoadableClassInfo[] FindLoadableClasses()
+ public LoadableClassInfo[] FindLoadableClasses()
{
- CacheLoadedAssemblies();
-
return _classes
.Where(ci => ci.SignatureTypes.Contains(typeof(TSig)))
.ToArray();
}
- public static LoadableClassInfo[] FindLoadableClasses()
+ public LoadableClassInfo[] FindLoadableClasses()
{
// REVIEW: this and above methods perform a linear search over all the loadable classes.
// On 6/15/2015, TLC release build contained 431 of them, so adding extra lookups looks unnecessary at this time.
- CacheLoadedAssemblies();
-
return _classes
.Where(ci => ci.ArgType == typeof(TArgs) && ci.SignatureTypes.Contains(typeof(TSig)))
.ToArray();
}
- public static LoadableClassInfo GetLoadableClassInfo(string loadName)
+ public LoadableClassInfo GetLoadableClassInfo(string loadName)
{
return GetLoadableClassInfo(loadName, typeof(TSig));
}
- public static LoadableClassInfo GetLoadableClassInfo(string loadName, Type signatureType)
+ public LoadableClassInfo GetLoadableClassInfo(string loadName, Type signatureType)
{
Contracts.CheckParam(signatureType.BaseType == typeof(MulticastDelegate), nameof(signatureType), "signatureType must be a delegate type");
Contracts.CheckValueOrNull(loadName);
@@ -815,7 +520,7 @@ private static bool TryCreateInstance(IHostEnvironment env, Type signature
env.CheckValueOrNull(name);
string nameLower = (name ?? "").ToLowerInvariant().Trim();
- LoadableClassInfo info = FindClassCore(new LoadableClassInfo.Key(nameLower, signatureType));
+ LoadableClassInfo info = env.ComponentCatalog.FindClassCore(new LoadableClassInfo.Key(nameLower, signatureType));
if (info == null)
{
result = null;
diff --git a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
index b463e52a8e..bd08a75c9a 100644
--- a/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
+++ b/src/Microsoft.ML.Core/Data/IHostEnvironment.cs
@@ -46,6 +46,11 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
///
bool IsCancelled { get; }
+ ///
+ /// The catalog of loadable components () that are available in this host.
+ ///
+ ComponentCatalog ComponentCatalog { get; }
+
///
/// Return a file handle for an input "file".
///
diff --git a/src/Microsoft.ML.Core/Data/ServerChannel.cs b/src/Microsoft.ML.Core/Data/ServerChannel.cs
index fd0e1f2603..4b6db41541 100644
--- a/src/Microsoft.ML.Core/Data/ServerChannel.cs
+++ b/src/Microsoft.ML.Core/Data/ServerChannel.cs
@@ -171,16 +171,16 @@ public interface IServer : IDisposable
/// for example, if a user opted to remove all implementations of and
/// the associated for security reasons.
///
- public static IServerFactory CreateDefaultServerFactoryOrNull(IExceptionContext ectx)
+ public static IServerFactory CreateDefaultServerFactoryOrNull(IHostEnvironment env)
{
- Contracts.CheckValue(ectx, nameof(ectx));
+ Contracts.CheckValue(env, nameof(env));
// REVIEW: There should be a better way. There currently isn't,
// but there should be. This is pretty horrifying, but it is preferable to
// the alternative of having core components depend on an actual server
// implementation, since we want those to be removable because of security
// concerns in certain environments (since not everyone will be wild about
// web servers popping up everywhere).
- var cat = ModuleCatalog.CreateInstance(ectx);
+ var cat = ModuleCatalog.CreateInstance(env);
ModuleCatalog.ComponentInfo component;
if (!cat.TryFindComponent(typeof(IServerFactory), "mini", out component))
return null;
diff --git a/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs b/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs
index 60511bfd39..a34242b6f4 100644
--- a/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs
@@ -152,7 +152,6 @@ internal ComponentInfo(IExceptionContext ectx, Type interfaceType, string kind,
}
}
- private static volatile ModuleCatalog _instance;
private readonly EntryPointInfo[] _entryPoints;
private readonly Dictionary _entryPointMap;
@@ -167,15 +166,15 @@ public IEnumerable AllEntryPoints()
return _entryPoints.AsEnumerable();
}
- private ModuleCatalog(IExceptionContext ectx)
+ private ModuleCatalog(IHostEnvironment env)
{
- Contracts.AssertValue(ectx);
+ Contracts.AssertValue(env);
_entryPointMap = new Dictionary();
_componentMap = new Dictionary();
_components = new List();
- var moduleClasses = ComponentCatalog.FindLoadableClasses();
+ var moduleClasses = env.ComponentCatalog.FindLoadableClasses();
var entryPoints = new List();
foreach (var lc in moduleClasses)
@@ -189,7 +188,7 @@ private ModuleCatalog(IExceptionContext ectx)
if (attr == null)
continue;
- var info = new EntryPointInfo(ectx, methodInfo, attr,
+ var info = new EntryPointInfo(env, methodInfo, attr,
methodInfo.GetCustomAttributes(typeof(ObsoleteAttribute), false).FirstOrDefault() as ObsoleteAttribute);
entryPoints.Add(info);
@@ -205,9 +204,9 @@ private ModuleCatalog(IExceptionContext ectx)
// Scan for components.
// First scan ourself, and then all nested types, for component info.
- ScanForComponents(ectx, type);
+ ScanForComponents(env, type);
foreach (var nestedType in type.GetTypeInfo().GetNestedTypes())
- ScanForComponents(ectx, nestedType);
+ ScanForComponents(env, nestedType);
}
_entryPoints = entryPoints.ToArray();
}
@@ -274,17 +273,13 @@ private static bool IsValidName(string name)
}
///
- /// Create a module catalog (or reuse the one created before).
+ /// Create a module catalog.
///
- /// The exception context to use to report errors while scanning the assemblies.
- public static ModuleCatalog CreateInstance(IExceptionContext ectx)
+ /// The host environment and exception context to use to report errors while scanning the assemblies.
+ public static ModuleCatalog CreateInstance(IHostEnvironment env)
{
- Contracts.CheckValueOrNull(ectx);
-#pragma warning disable 420 // volatile with Interlocked.CompareExchange.
- if (_instance == null)
- Interlocked.CompareExchange(ref _instance, new ModuleCatalog(ectx), null);
-#pragma warning restore 420
- return _instance;
+ Contracts.CheckValue(env, nameof(env));
+ return new ModuleCatalog(env);
}
public bool TryFindEntryPoint(string name, out EntryPointInfo entryPoint)
diff --git a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
index d4ff5ccd96..b195de6108 100644
--- a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
+++ b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
@@ -6,7 +6,6 @@
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
-using System.Threading;
namespace Microsoft.ML.Runtime.Data
{
@@ -382,6 +381,8 @@ public void RemoveListener(Action listenerFunc)
public bool IsCancelled { get; protected set; }
+ public ComponentCatalog ComponentCatalog { get; }
+
public override int Depth => 0;
protected bool IsDisposed => _tempFiles == null;
@@ -402,6 +403,7 @@ protected HostEnvironmentBase(IRandom rand, bool verbose, int conc,
_tempLock = new object();
_tempFiles = new List();
Root = this as TEnv;
+ ComponentCatalog = new ComponentCatalog();
}
///
@@ -422,6 +424,7 @@ protected HostEnvironmentBase(HostEnvironmentBase source, IRandom rand, bo
Root = source.Root;
ListenerDict = source.ListenerDict;
ProgressTracker = source.ProgressTracker;
+ ComponentCatalog = source.ComponentCatalog;
}
///
diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
index f6670394ee..837c3beb3e 100644
--- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
@@ -119,7 +119,7 @@ public override void Run()
using (var ch = Host.Start(LoadName))
using (var server = InitServer(ch))
{
- var settings = CmdParser.GetSettings(ch, Args, new Arguments());
+ var settings = CmdParser.GetSettings(Host, Args, new Arguments());
string cmd = string.Format("maml.exe {0} {1}", LoadName, settings);
ch.Info(cmd);
@@ -557,7 +557,7 @@ private FoldResult RunFold(int fold)
var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor, scorerFactorySettings: _scorer as ICommandLineComponentFactory);
ch.AssertValue(bindable);
var mapper = bindable.Bind(host, testData.Schema);
- var scorerComp = _scorer ?? ScoreUtils.GetScorerComponent(mapper);
+ var scorerComp = _scorer ?? ScoreUtils.GetScorerComponent(host, mapper);
IDataScorerTransform scorePipe = scorerComp.CreateComponent(host, testData.Data, mapper, trainData.Schema);
// Save per-fold model.
diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
index 301603b14a..014fdf87f8 100644
--- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
@@ -120,7 +120,7 @@ private void RunCore(IChannel ch)
var mapper = bindable.Bind(Host, schema);
if (scorer == null)
- scorer = ScoreUtils.GetScorerComponent(mapper);
+ scorer = ScoreUtils.GetScorerComponent(Host, mapper);
loader = CompositeDataLoader.ApplyTransform(Host, loader, "Scorer", scorer.ToString(),
(env, view) => scorer.CreateComponent(env, view, mapper, trainSchema));
@@ -284,7 +284,7 @@ private static TScorerFactory GetScorerComponentAndMapper(
mapper = bindable.Bind(env, schema);
if (scorerFactory != null)
return scorerFactory;
- return GetScorerComponent(mapper);
+ return GetScorerComponent(env, mapper);
}
///
@@ -292,12 +292,15 @@ private static TScorerFactory GetScorerComponentAndMapper(
/// metadata on the first column of the mapper. If that text is found and maps to a scorer loadable class,
/// that component is used. Otherwise, the GenericScorer is used.
///
+ /// The host environment..
/// The schema bound mapper to get the default scorer..
/// An optional suffix to append to the default column names.
public static TScorerFactory GetScorerComponent(
+ IHostEnvironment environment,
ISchemaBoundMapper mapper,
string suffix = null)
{
+ Contracts.CheckValue(environment, nameof(environment));
Contracts.AssertValue(mapper);
ComponentCatalog.LoadableClassInfo info = null;
@@ -307,7 +310,7 @@ public static TScorerFactory GetScorerComponent(
!scoreKind.IsEmpty)
{
var loadName = scoreKind.ToString();
- info = ComponentCatalog.GetLoadableClassInfo(loadName);
+ info = environment.ComponentCatalog.GetLoadableClassInfo(loadName);
if (info == null || !typeof(IDataScorerTransform).IsAssignableFrom(info.Type))
info = null;
}
diff --git a/src/Microsoft.ML.Data/Commands/TestCommand.cs b/src/Microsoft.ML.Data/Commands/TestCommand.cs
index ca32da0ddd..1630d9af57 100644
--- a/src/Microsoft.ML.Data/Commands/TestCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TestCommand.cs
@@ -67,7 +67,7 @@ public override void Run()
using (var ch = Host.Start(command))
using (var server = InitServer(ch))
{
- var settings = CmdParser.GetSettings(ch, Args, new Arguments());
+ var settings = CmdParser.GetSettings(Host, Args, new Arguments());
ch.Info("maml.exe {0} {1}", command, settings);
SendTelemetry(Host);
diff --git a/src/Microsoft.ML.Data/Commands/TrainCommand.cs b/src/Microsoft.ML.Data/Commands/TrainCommand.cs
index acc03b743f..1a90881464 100644
--- a/src/Microsoft.ML.Data/Commands/TrainCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TrainCommand.cs
@@ -108,7 +108,7 @@ public override void Run()
using (var ch = Host.Start(command))
using (var server = InitServer(ch))
{
- var settings = CmdParser.GetSettings(ch, Args, new Arguments());
+ var settings = CmdParser.GetSettings(Host, Args, new Arguments());
string cmd = string.Format("maml.exe {0} {1}", command, settings);
ch.Info(cmd);
diff --git a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
index 270f7c3b03..63c8691747 100644
--- a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs
@@ -94,7 +94,7 @@ public override void Run()
using (var ch = Host.Start(LoadName))
using (var server = InitServer(ch))
{
- var settings = CmdParser.GetSettings(ch, Args, new Arguments());
+ var settings = CmdParser.GetSettings(Host, Args, new Arguments());
string cmd = string.Format("maml.exe {0} {1}", LoadName, settings);
ch.Info(cmd);
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs
index 582212738a..9991956fcb 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs
@@ -778,7 +778,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010003, // Number of blocks to put in the shuffle pool
verReadableCur: 0x00010003,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(BinaryLoader).Assembly.FullName);
}
private BinaryLoader(Arguments args, IHost host, Stream stream, bool leaveOpen)
diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
index cb18bdce1d..d71f6b9ac5 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
@@ -71,7 +71,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Added transform tags and args strings
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(CompositeDataLoader).Assembly.FullName);
}
// The composition of loader plus transforms in order.
diff --git a/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs
index 27ac28c717..6c52cefa56 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs
@@ -60,7 +60,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoadName);
+ loaderSignature: LoadName,
+ loaderAssemblyName: typeof(PartitionedFileLoader).Assembly.FullName);
}
public class Arguments
diff --git a/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs b/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs
index 70d8f898ab..c10fe09116 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs
@@ -91,7 +91,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoadName);
+ loaderSignature: LoadName,
+ loaderAssemblyName: typeof(SimplePartitionedPathParser).Assembly.FullName);
}
private IHost _host;
@@ -214,7 +215,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoadName);
+ loaderSignature: LoadName,
+ loaderAssemblyName: typeof(ParquetPartitionedPathParser).Assembly.FullName);
}
public ParquetPartitionedPathParser()
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
index 3663c93cc4..195befe518 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
@@ -960,7 +960,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x0001000B, // Header now retained if used and present
verReadableCur: 0x0001000A,
verWeCanReadBack: 0x00010009,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(TextLoader).Assembly.FullName);
}
///
@@ -1180,13 +1181,13 @@ private static bool TryParseSchema(IHost host, IMultiStreamSource files,
goto LDone;
// Make sure the loader binds to us.
- var info = ComponentCatalog.GetLoadableClassInfo(loader.Name);
+ var info = host.ComponentCatalog.GetLoadableClassInfo(loader.Name);
if (info.Type != typeof(IDataLoader) || info.ArgType != typeof(Arguments))
goto LDone;
var argsNew = new Arguments();
// Copy the non-core arguments to the new args (we already know that all the core arguments are default).
- var parsed = CmdParser.ParseArguments(host, CmdParser.GetSettings(ch, args, new Arguments()), argsNew);
+ var parsed = CmdParser.ParseArguments(host, CmdParser.GetSettings(host, args, new Arguments()), argsNew);
ch.Assert(parsed);
// Copy the core arguments to the new args.
if (!CmdParser.ParseArguments(host, loader.GetSettingsString(), argsNew, typeof(ArgumentsCore), msg => ch.Error(msg)))
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs
index 64cf7f7faf..73f512b478 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs
@@ -468,7 +468,7 @@ private string CreateLoaderArguments(ISchema schema, ValueWriter[] pipes, bool h
sb.Append(" col=");
if (!column.TryUnparse(sb))
{
- var settings = CmdParser.GetSettings(ch, column, new TextLoader.Column());
+ var settings = CmdParser.GetSettings(_host, column, new TextLoader.Column());
CmdQuoter.QuoteValue(settings, sb, true);
}
if (type.IsVector && !type.IsKnownSizeVector && i != pipes.Length - 1)
diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
index 64130ed80e..cd4450615d 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
@@ -77,7 +77,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(TransformWrapper).Assembly.FullName);
}
// Factory for SignatureLoadModel.
diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs
index d5f246689a..31ab59a100 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs
@@ -55,7 +55,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: TransformerChain.LoaderSignature);
+ loaderSignature: TransformerChain.LoaderSignature,
+ loaderAssemblyName: typeof(TransformerChain<>).Assembly.FullName);
}
///
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs
index 37cbe23b92..a6c12b48ed 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs
@@ -354,7 +354,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoadName);
+ loaderSignature: LoadName,
+ loaderAssemblyName: typeof(TransposeLoader).Assembly.FullName);
}
// We return the schema view's schema, because we don't necessarily want
diff --git a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
index ef19012bf6..799741fa1a 100644
--- a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
+++ b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
@@ -241,7 +241,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(RowToRowMapperTransform).Assembly.FullName);
}
public override ISchema Schema { get { return _bindings; } }
diff --git a/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs b/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs
index 58ba66e091..dfa90bfec6 100644
--- a/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs
+++ b/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs
@@ -222,7 +222,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(FeatureNameCollection).Assembly.FullName);
}
public static void Save(ModelSaveContext ctx, ref VBuffer> names)
diff --git a/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs b/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs
index da237611b9..df0e0ff049 100644
--- a/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs
+++ b/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs
@@ -189,7 +189,8 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
- loaderSignatureAlt: LoaderSignatureOld);
+ loaderSignatureAlt: LoaderSignatureOld,
+ loaderAssemblyName: typeof(ChooseColumnsByIndexTransform).Assembly.FullName);
}
private readonly Bindings _bindings;
diff --git a/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs b/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs
index 260ff37f26..195b2ec5f6 100644
--- a/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs
@@ -83,7 +83,7 @@ public static Output Score(IHostEnvironment env, Input input)
ch.AssertValue(bindable);
var mapper = bindable.Bind(host, data.Schema);
- var scorer = ScoreUtils.GetScorerComponent(mapper, input.Suffix);
+ var scorer = ScoreUtils.GetScorerComponent(host, mapper, input.Suffix);
scoredPipe = scorer.CreateComponent(host, data.Data, mapper, input.PredictorModel.GetTrainingSchema(host));
ch.Done();
}
@@ -134,7 +134,7 @@ public static Output MakeScoringTransform(IHostEnvironment env, ModelInput input
ch.AssertValue(bindable);
var mapper = bindable.Bind(host, data.Schema);
- var scorer = ScoreUtils.GetScorerComponent(mapper);
+ var scorer = ScoreUtils.GetScorerComponent(host, mapper);
scoredPipe = scorer.CreateComponent(host, data.Data, mapper, input.PredictorModel.GetTrainingSchema(host));
ch.Done();
}
diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
index a92526dfa9..2247c32637 100644
--- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
@@ -1034,7 +1034,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(BinaryPerInstanceEvaluator).Assembly.FullName);
}
private const int AssignedCol = 0;
diff --git a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs
index fbbafa8775..2c476eb468 100644
--- a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs
@@ -529,7 +529,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(ClusteringPerInstanceEvaluator).Assembly.FullName);
}
private const int ClusterIdCol = 0;
diff --git a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs
index 84fd29ea50..c264d438b0 100644
--- a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs
@@ -385,7 +385,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(MultiOutputRegressionPerInstanceEvaluator).Assembly.FullName);
}
private const int LabelOutput = 0;
diff --git a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs
index 900df4ac53..205d74a526 100644
--- a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs
@@ -671,7 +671,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Serialize the class names
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(MultiClassPerInstanceEvaluator).Assembly.FullName);
}
private const int AssignedCol = 0;
diff --git a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs
index d4c70c6eac..9a2f22de5d 100644
--- a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs
@@ -262,7 +262,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(QuantileRegressionPerInstanceEvaluator).Assembly.FullName);
}
private const int L1Col = 0;
diff --git a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
index 131f16b058..d071a906e3 100644
--- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
@@ -523,7 +523,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(RankerPerInstanceTransform).Assembly.FullName);
}
public const string Ndcg = "NDCG";
diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs
index b20492d6dc..203050d774 100644
--- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs
@@ -293,7 +293,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(RegressionPerInstanceEvaluator).Assembly.FullName);
}
private const int L1Col = 0;
diff --git a/src/Microsoft.ML.Data/Model/ModelHeader.cs b/src/Microsoft.ML.Data/Model/ModelHeader.cs
index 067ff12285..37a5b9ac92 100644
--- a/src/Microsoft.ML.Data/Model/ModelHeader.cs
+++ b/src/Microsoft.ML.Data/Model/ModelHeader.cs
@@ -20,11 +20,14 @@ public struct ModelHeader
public const ulong SignatureValue = 0x4C45444F4D004C4DUL;
public const ulong TailSignatureValue = 0x4D4C004D4F44454CUL;
+ private const uint VerAssemblyNameSupported = 0x00010002;
+
// These are private since they change over time. If we make them public we risk
// another assembly containing a "copy" of their value when the other assembly
// was compiled, which might not match the code that can load this.
- private const uint VerWrittenCur = 0x00010001;
- private const uint VerReadableCur = 0x00010001;
+ //private const uint VerWrittenCur = 0x00010001; // Initial
+ private const uint VerWrittenCur = 0x00010002; // Added AssemblyName
+ private const uint VerReadableCur = 0x00010002;
private const uint VerWeCanReadBack = 0x00010001;
[FieldOffset(0x00)]
@@ -85,7 +88,13 @@ public struct ModelHeader
[FieldOffset(0x88)]
public long FpLim;
- // Lots of padding....
+ // Location of the fully qualified assembly name string (in UTF-16).
+ // Note that it is legal for both to be zero.
+ [FieldOffset(0x90)]
+ public long FpAssemblyName;
+ [FieldOffset(0x98)]
+ public uint CbAssemblyName;
+
public const int Size = 0x0100;
// Utilities for writing.
@@ -116,7 +125,7 @@ public static void BeginWrite(BinaryWriter writer, out long fpMin, out ModelHead
/// The current writer position should be the end of the model blob. Records the model size, writes the string table,
/// completes and writes the header, and writes the tail.
///
- public static void EndWrite(BinaryWriter writer, long fpMin, ref ModelHeader header, NormStr.Pool pool = null)
+ public static void EndWrite(BinaryWriter writer, long fpMin, ref ModelHeader header, NormStr.Pool pool = null, string loaderAssemblyName = null)
{
Contracts.CheckValue(writer, nameof(writer));
Contracts.CheckParam(fpMin >= 0, nameof(fpMin));
@@ -157,9 +166,28 @@ public static void EndWrite(BinaryWriter writer, long fpMin, ref ModelHeader hea
Contracts.Assert(offset == header.CbStringChars);
}
+ WriteLoaderAssemblyName(writer, fpMin, ref header, loaderAssemblyName);
+
WriteHeaderAndTailCore(writer, fpMin, ref header);
}
+ private static void WriteLoaderAssemblyName(BinaryWriter writer, long fpMin, ref ModelHeader header, string loaderAssemblyName)
+ {
+ if (!string.IsNullOrEmpty(loaderAssemblyName))
+ {
+ header.FpAssemblyName = writer.FpCur() - fpMin;
+ header.CbAssemblyName = (uint)loaderAssemblyName.Length * sizeof(char);
+
+ foreach (var ch in loaderAssemblyName)
+ writer.Write((short)ch);
+ }
+ else
+ {
+ header.FpAssemblyName = 0;
+ header.CbAssemblyName = 0;
+ }
+ }
+
///
/// The current writer position should be where the tail belongs. Writes the header and tail.
/// Typically this isn't called directly unless you are doing custom string table serialization.
@@ -289,7 +317,7 @@ public static void MarshalToBytes(ref ModelHeader header, byte[] bytes)
/// Read the model header, strings, etc from reader. Also validates the header (throws if bad).
/// Leaves the reader position at the beginning of the model blob.
///
- public static void BeginRead(out long fpMin, out ModelHeader header, out string[] strings, BinaryReader reader)
+ public static void BeginRead(out long fpMin, out ModelHeader header, out string[] strings, out string loaderAssemblyName, BinaryReader reader)
{
fpMin = reader.FpCur();
@@ -298,7 +326,7 @@ public static void BeginRead(out long fpMin, out ModelHeader header, out string[
ModelHeader.MarshalFromBytes(out header, headerBytes);
Exception ex;
- if (!ModelHeader.TryValidate(ref header, reader, fpMin, out strings, out ex))
+ if (!ModelHeader.TryValidate(ref header, reader, fpMin, out strings, out loaderAssemblyName, out ex))
throw ex;
reader.Seek(header.FpModel + fpMin);
@@ -375,7 +403,10 @@ public static bool TryValidate(ref ModelHeader header, long size, out Exception
Contracts.CheckDecode(header.CbStringTable == 0);
Contracts.CheckDecode(header.FpStringChars == 0);
Contracts.CheckDecode(header.CbStringChars == 0);
- Contracts.CheckDecode(header.FpTail == header.FpModel + header.CbModel);
+ if (header.VerWritten < VerAssemblyNameSupported || header.FpAssemblyName == 0)
+ {
+ Contracts.CheckDecode(header.FpTail == header.FpModel + header.CbModel);
+ }
}
else
{
@@ -387,8 +418,34 @@ public static bool TryValidate(ref ModelHeader header, long size, out Exception
Contracts.CheckDecode(header.FpStringChars == header.FpStringTable + header.CbStringTable);
Contracts.CheckDecode(header.CbStringChars % sizeof(char) == 0);
Contracts.CheckDecode(header.FpStringChars + header.CbStringChars >= header.FpStringChars);
- Contracts.CheckDecode(header.FpTail == header.FpStringChars + header.CbStringChars);
+ if (header.VerWritten < VerAssemblyNameSupported || header.FpAssemblyName == 0)
+ {
+ Contracts.CheckDecode(header.FpTail == header.FpStringChars + header.CbStringChars);
+ }
}
+
+ if (header.VerWritten >= VerAssemblyNameSupported)
+ {
+ if (header.FpAssemblyName == 0)
+ {
+ Contracts.CheckDecode(header.CbAssemblyName == 0);
+ }
+ else
+ {
+ // the assembly name always immediately after the string table, if there is one
+ if (header.FpStringTable == 0)
+ {
+ Contracts.CheckDecode(header.FpAssemblyName == header.FpModel + header.CbModel);
+ }
+ else
+ {
+ Contracts.CheckDecode(header.FpAssemblyName == header.FpStringChars + header.CbStringChars);
+ }
+ Contracts.CheckDecode(header.CbAssemblyName % sizeof(char) == 0);
+ Contracts.CheckDecode(header.FpTail == header.FpAssemblyName + header.CbAssemblyName);
+ }
+ }
+
Contracts.CheckDecode(header.FpLim == header.FpTail + sizeof(ulong));
Contracts.CheckDecode(size == 0 || size >= header.FpLim);
@@ -405,7 +462,7 @@ public static bool TryValidate(ref ModelHeader header, long size, out Exception
///
/// Checks the validity of the header, reads the string table, etc.
///
- public static bool TryValidate(ref ModelHeader header, BinaryReader reader, long fpMin, out string[] strings, out Exception ex)
+ public static bool TryValidate(ref ModelHeader header, BinaryReader reader, long fpMin, out string[] strings, out string loaderAssemblyName, out Exception ex)
{
Contracts.CheckValue(reader, nameof(reader));
Contracts.Check(fpMin >= 0);
@@ -413,62 +470,85 @@ public static bool TryValidate(ref ModelHeader header, BinaryReader reader, long
if (!TryValidate(ref header, reader.BaseStream.Length - fpMin, out ex))
{
strings = null;
+ loaderAssemblyName = null;
return false;
}
- if (header.FpStringTable == 0)
- {
- // No strings.
- strings = null;
- ex = null;
- return true;
- }
-
try
{
long fpOrig = reader.FpCur();
- reader.Seek(header.FpStringTable + fpMin);
- Contracts.Assert(reader.FpCur() == header.FpStringTable + fpMin);
- long cstr = header.CbStringTable / sizeof(long);
- Contracts.Assert(cstr < int.MaxValue);
- long[] offsets = reader.ReadLongArray((int)cstr);
- Contracts.Assert(header.FpStringChars == reader.FpCur() - fpMin);
- Contracts.CheckDecode(offsets[cstr - 1] == header.CbStringChars);
+ StringBuilder sb = null;
+ if (header.FpStringTable == 0)
+ {
+ // No strings.
+ strings = null;
+ }
+ else
+ {
+ reader.Seek(header.FpStringTable + fpMin);
+ Contracts.Assert(reader.FpCur() == header.FpStringTable + fpMin);
+
+ long cstr = header.CbStringTable / sizeof(long);
+ Contracts.Assert(cstr < int.MaxValue);
+ long[] offsets = reader.ReadLongArray((int)cstr);
+ Contracts.Assert(header.FpStringChars == reader.FpCur() - fpMin);
+ Contracts.CheckDecode(offsets[cstr - 1] == header.CbStringChars);
+
+ strings = new string[cstr];
+ long offset = 0;
+ sb = new StringBuilder();
+ for (int i = 0; i < offsets.Length; i++)
+ {
+ Contracts.CheckDecode(header.FpStringChars + offset == reader.FpCur() - fpMin);
+
+ long offsetPrev = offset;
+ offset = offsets[i];
+ Contracts.CheckDecode(offsetPrev <= offset & offset <= header.CbStringChars);
+ Contracts.CheckDecode(offset % sizeof(char) == 0);
+ long cch = (offset - offsetPrev) / sizeof(char);
+ Contracts.CheckDecode(cch < int.MaxValue);
+
+ sb.Clear();
+ for (long ich = 0; ich < cch; ich++)
+ sb.Append((char)reader.ReadUInt16());
+ strings[i] = sb.ToString();
+ }
+ Contracts.CheckDecode(offset == header.CbStringChars);
+ Contracts.CheckDecode(header.FpStringChars + header.CbStringChars == reader.FpCur() - fpMin);
+ }
- strings = new string[cstr];
- long offset = 0;
- var sb = new StringBuilder();
- for (int i = 0; i < offsets.Length; i++)
+ if (header.VerWritten >= VerAssemblyNameSupported && header.FpAssemblyName != 0)
{
- Contracts.CheckDecode(header.FpStringChars + offset == reader.FpCur() - fpMin);
+ reader.Seek(header.FpAssemblyName + fpMin);
+ int assemblyNameLength = (int)header.CbAssemblyName / sizeof(char);
- long offsetPrev = offset;
- offset = offsets[i];
- Contracts.CheckDecode(offsetPrev <= offset & offset <= header.CbStringChars);
- Contracts.CheckDecode(offset % sizeof(char) == 0);
- long cch = (offset - offsetPrev) / sizeof(char);
- Contracts.CheckDecode(cch < int.MaxValue);
+ sb = sb != null ? sb.Clear() : new StringBuilder(assemblyNameLength);
- sb.Clear();
- for (long ich = 0; ich < cch; ich++)
+ for (long ich = 0; ich < assemblyNameLength; ich++)
sb.Append((char)reader.ReadUInt16());
- strings[i] = sb.ToString();
+
+ loaderAssemblyName = sb.ToString();
}
- Contracts.CheckDecode(offset == header.CbStringChars);
- Contracts.CheckDecode(header.FpStringChars + header.CbStringChars == reader.FpCur() - fpMin);
+ else
+ {
+ loaderAssemblyName = null;
+ }
+
Contracts.CheckDecode(header.FpTail == reader.FpCur() - fpMin);
ulong tail = reader.ReadUInt64();
Contracts.CheckDecode(tail == TailSignatureValue, "Corrupt model file tail");
- reader.Seek(fpOrig);
ex = null;
+
+ reader.Seek(fpOrig);
return true;
}
catch (Exception e)
{
strings = null;
+ loaderAssemblyName = null;
ex = e;
return false;
}
@@ -529,12 +609,13 @@ public static string GetLoaderSigAlt(ref ModelHeader header)
/// This is used to simplify version checking boiler-plate code. It is an optional
/// utility type.
///
- public struct VersionInfo
+ public readonly struct VersionInfo
{
public readonly ulong ModelSignature;
public readonly uint VerWrittenCur;
public readonly uint VerReadableCur;
public readonly uint VerWeCanReadBack;
+ public readonly string LoaderAssemblyName;
public readonly string LoaderSignature;
public readonly string LoaderSignatureAlt;
@@ -543,7 +624,7 @@ public struct VersionInfo
/// all less than 0x100. Spaces are mapped to zero. This assumes little-endian.
///
public VersionInfo(string modelSignature, uint verWrittenCur, uint verReadableCur, uint verWeCanReadBack,
- string loaderSignature = null, string loaderSignatureAlt = null)
+ string loaderAssemblyName, string loaderSignature = null, string loaderSignatureAlt = null)
{
Contracts.Check(Utils.Size(modelSignature) == 8, "Model signature must be eight characters");
ModelSignature = 0;
@@ -559,17 +640,7 @@ public VersionInfo(string modelSignature, uint verWrittenCur, uint verReadableCu
VerWrittenCur = verWrittenCur;
VerReadableCur = verReadableCur;
VerWeCanReadBack = verWeCanReadBack;
- LoaderSignature = loaderSignature;
- LoaderSignatureAlt = loaderSignatureAlt;
- }
-
- public VersionInfo(ulong modelSignature, uint verWrittenCur, uint verReadableCur, uint verWeCanReadBack,
- string loaderSignature = null, string loaderSignatureAlt = null)
- {
- ModelSignature = modelSignature;
- VerWrittenCur = verWrittenCur;
- VerReadableCur = verReadableCur;
- VerWeCanReadBack = verWeCanReadBack;
+ LoaderAssemblyName = loaderAssemblyName;
LoaderSignature = loaderSignature;
LoaderSignatureAlt = loaderSignatureAlt;
}
diff --git a/src/Microsoft.ML.Data/Model/ModelLoadContext.cs b/src/Microsoft.ML.Data/Model/ModelLoadContext.cs
index c09b7b09a2..7efd9f4b47 100644
--- a/src/Microsoft.ML.Data/Model/ModelLoadContext.cs
+++ b/src/Microsoft.ML.Data/Model/ModelLoadContext.cs
@@ -39,6 +39,14 @@ public sealed partial class ModelLoadContext : IDisposable
///
public readonly string[] Strings;
+ ///
+ /// The name of the assembly that the loader lives in.
+ ///
+ ///
+ /// This may be null or empty if one was never written to the model, or is an older model version.
+ ///
+ public readonly string LoaderAssemblyName;
+
///
/// The main stream's model header.
///
@@ -76,7 +84,7 @@ public ModelLoadContext(RepositoryReader rep, Repository.Entry ent, string dir)
Reader = new BinaryReader(ent.Stream, Encoding.UTF8, leaveOpen: true);
try
{
- ModelHeader.BeginRead(out FpMin, out Header, out Strings, Reader);
+ ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader);
}
catch
{
@@ -97,7 +105,7 @@ public ModelLoadContext(BinaryReader reader, IExceptionContext ectx = null)
Repository = null;
Directory = null;
Reader = reader;
- ModelHeader.BeginRead(out FpMin, out Header, out Strings, Reader);
+ ModelHeader.BeginRead(out FpMin, out Header, out Strings, out LoaderAssemblyName, Reader);
}
public void CheckAtModel()
diff --git a/src/Microsoft.ML.Data/Model/ModelLoading.cs b/src/Microsoft.ML.Data/Model/ModelLoading.cs
index c69461ea79..981da9e797 100644
--- a/src/Microsoft.ML.Data/Model/ModelLoading.cs
+++ b/src/Microsoft.ML.Data/Model/ModelLoading.cs
@@ -4,6 +4,7 @@
using System;
using System.IO;
+using System.Reflection;
using System.Text;
using Microsoft.ML.Runtime.Internal.Utilities;
@@ -212,6 +213,8 @@ private bool TryLoadModelCore(IHostEnvironment env, out TRes result,
var args = ConcatArgsRev(extra, this);
+ EnsureLoaderAssemblyIsRegistered(env.ComponentCatalog);
+
object tmp;
string sig = ModelHeader.GetLoaderSig(ref Header);
if (!string.IsNullOrWhiteSpace(sig) &&
@@ -246,6 +249,15 @@ private bool TryLoadModelCore(IHostEnvironment env, out TRes result,
return false;
}
+ private void EnsureLoaderAssemblyIsRegistered(ComponentCatalog catalog)
+ {
+ if (!string.IsNullOrEmpty(LoaderAssemblyName))
+ {
+ var assembly = Assembly.Load(LoaderAssemblyName);
+ catalog.RegisterAssembly(assembly);
+ }
+ }
+
private static object[] ConcatArgsRev(object[] args2, params object[] args1)
{
Contracts.AssertNonEmpty(args1);
diff --git a/src/Microsoft.ML.Data/Model/ModelSaveContext.cs b/src/Microsoft.ML.Data/Model/ModelSaveContext.cs
index 5e32893e7d..c5a9199758 100644
--- a/src/Microsoft.ML.Data/Model/ModelSaveContext.cs
+++ b/src/Microsoft.ML.Data/Model/ModelSaveContext.cs
@@ -24,7 +24,7 @@ public sealed partial class ModelSaveContext : IDisposable
public readonly RepositoryWriter Repository;
///
- /// When in repository mode, this is the direcory we're reading from. Null means the root
+ /// When in repository mode, this is the directory we're reading from. Null means the root
/// of the repository. It is always null in single-stream mode.
///
public readonly string Directory;
@@ -59,6 +59,11 @@ public sealed partial class ModelSaveContext : IDisposable
///
private readonly IExceptionContext _ectx;
+ ///
+ /// The assembly name where the loader resides.
+ ///
+ private string _loaderAssemblyName;
+
///
/// Returns whether this context is in repository mode (true) or single-stream mode (false).
///
@@ -131,6 +136,7 @@ public void CheckAtModel()
public void SetVersionInfo(VersionInfo ver)
{
ModelHeader.SetVersionInfo(ref Header, ver);
+ _loaderAssemblyName = ver.LoaderAssemblyName;
}
public void SaveTextStream(string name, Action action)
@@ -212,7 +218,7 @@ public void SaveNonEmptyString(ReadOnlyMemory str)
public void Done()
{
_ectx.Check(Header.ModelSignature != 0, "ModelSignature not specified!");
- ModelHeader.EndWrite(Writer, FpMin, ref Header, Strings);
+ ModelHeader.EndWrite(Writer, FpMin, ref Header, Strings, _loaderAssemblyName);
Dispose();
}
diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
index 6cee31b9d2..310774a8a5 100644
--- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs
+++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
@@ -332,7 +332,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(CalibratedPredictor).Assembly.FullName);
}
private static VersionInfo GetVersionInfoBulk()
{
@@ -341,7 +342,8 @@ private static VersionInfo GetVersionInfoBulk()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(CalibratedPredictor).Assembly.FullName);
}
private CalibratedPredictor(IHostEnvironment env, ModelLoadContext ctx)
@@ -393,7 +395,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(FeatureWeightsCalibratedPredictor).Assembly.FullName);
}
private FeatureWeightsCalibratedPredictor(IHostEnvironment env, ModelLoadContext ctx)
@@ -455,7 +458,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(ParameterMixingCalibratedPredictor).Assembly.FullName);
}
private ParameterMixingCalibratedPredictor(IHostEnvironment env, ModelLoadContext ctx)
@@ -608,7 +612,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(SchemaBindableCalibratedPredictor).Assembly.FullName);
}
///
@@ -967,7 +972,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(NaiveCalibrator).Assembly.FullName);
}
private readonly IHost _host;
@@ -1334,7 +1340,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(PlattCalibrator).Assembly.FullName);
}
private readonly IHost _host;
@@ -1569,7 +1576,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(PavCalibrator).Assembly.FullName);
}
// Epsilon for 0-comparisons
diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
index ee76b5a674..dc9a5bbdf9 100644
--- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
+++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
@@ -40,7 +40,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010004, // ISchemaBindableMapper update
verReadableCur: 0x00010004,
verWeCanReadBack: 0x00010004,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(BinaryClassifierScorer).Assembly.FullName);
}
private const string RegistrationName = "BinaryClassifierScore";
diff --git a/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs b/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs
index 7a065fe4d4..26eae0154d 100644
--- a/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs
+++ b/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs
@@ -37,7 +37,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010003, // ISchemaBindableMapper update
verReadableCur: 0x00010003,
verWeCanReadBack: 0x00010003,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(ClusteringScorer).Assembly.FullName);
}
private const string RegistrationName = "ClusteringScore";
diff --git a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs
index 41c12e94ed..0b261502ed 100644
--- a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs
+++ b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs
@@ -130,7 +130,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(GenericScorer).Assembly.FullName);
}
private const string RegistrationName = "GenericScore";
diff --git a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs
index c12fd9b4d1..79387c2727 100644
--- a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs
+++ b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs
@@ -48,7 +48,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010003, // ISchemaBindableMapper update
verReadableCur: 0x00010003,
verWeCanReadBack: 0x00010003,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(MultiClassClassifierScorer).Assembly.FullName);
}
private const string RegistrationName = "MultiClassClassifierScore";
@@ -88,7 +89,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Added metadataKind
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(LabelNameBindableMapper).Assembly.FullName);
}
private const int VersionAddedMetadataKind = 0x00010002;
diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
index aa71ac2dc5..904b591dd4 100644
--- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
+++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
@@ -256,7 +256,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: BinaryPredictionTransformer.LoaderSignature);
+ loaderSignature: BinaryPredictionTransformer.LoaderSignature,
+ loaderAssemblyName: typeof(BinaryPredictionTransformer<>).Assembly.FullName);
}
}
@@ -321,7 +322,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: MulticlassPredictionTransformer.LoaderSignature);
+ loaderSignature: MulticlassPredictionTransformer.LoaderSignature,
+ loaderAssemblyName: typeof(MulticlassPredictionTransformer<>).Assembly.FullName);
}
}
@@ -371,7 +373,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: RegressionPredictionTransformer.LoaderSignature);
+ loaderSignature: RegressionPredictionTransformer.LoaderSignature,
+ loaderAssemblyName: typeof(RegressionPredictionTransformer<>).Assembly.FullName);
}
}
@@ -417,7 +420,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: RankingPredictionTransformer.LoaderSignature);
+ loaderSignature: RankingPredictionTransformer.LoaderSignature,
+ loaderAssemblyName: typeof(RankingPredictionTransformer<>).Assembly.FullName);
}
}
diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs
index 31b921a064..8422d5042c 100644
--- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs
+++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs
@@ -242,7 +242,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // ISchemaBindableWrapper update
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010002,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(SchemaBindablePredictorWrapper).Assembly.FullName);
}
private readonly string _scoreColumnKind;
@@ -353,7 +354,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // ISchemaBindableWrapper update
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010002,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(SchemaBindableBinaryPredictorWrapper).Assembly.FullName);
}
private readonly IValueMapperDist _distMapper;
@@ -581,7 +583,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // ISchemaBindableWrapper update
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010002,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(SchemaBindableQuantileRegressionPredictor).Assembly.FullName);
}
private readonly IQuantileValueMapper _qpred;
diff --git a/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs
index 1459f55cab..e1300e38cd 100644
--- a/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs
@@ -449,7 +449,8 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
- loaderSignatureAlt: LoaderSignatureOld);
+ loaderSignatureAlt: LoaderSignatureOld,
+ loaderAssemblyName: typeof(ChooseColumnsTransform).Assembly.FullName);
}
private readonly Bindings _bindings;
diff --git a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs
index d827947192..1ec6c13d61 100644
--- a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs
@@ -243,7 +243,8 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
- loaderSignatureAlt: LoaderSignatureOld);
+ loaderSignatureAlt: LoaderSignatureOld,
+ loaderAssemblyName: typeof(ConcatTransform).Assembly.FullName);
}
private const int VersionAddedAliases = 0x00010002;
diff --git a/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs b/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs
index eec59c8a8f..62906ffa55 100644
--- a/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs
@@ -161,7 +161,8 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
- loaderSignatureAlt: LoaderSignatureOld);
+ loaderSignatureAlt: LoaderSignatureOld,
+ loaderAssemblyName: typeof(ConvertTransform).Assembly.FullName);
}
private const string RegistrationName = "Convert";
diff --git a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
index 478d509c24..8ca1260ebf 100644
--- a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
@@ -93,7 +93,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(CopyColumnsTransform).Assembly.FullName);
}
public CopyColumnsTransform(IHostEnvironment env, params (string source, string name)[] columns)
@@ -224,7 +225,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(CopyColumnsRowMapper).Assembly.FullName);
}
// Factory method for SignatureLoadRowMapper.
diff --git a/src/Microsoft.ML.Data/Transforms/DropColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/DropColumnsTransform.cs
index 3e15199ff7..aa5026a914 100644
--- a/src/Microsoft.ML.Data/Transforms/DropColumnsTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/DropColumnsTransform.cs
@@ -229,7 +229,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Added KeepColumns
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010002,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(DropColumnsTransform).Assembly.FullName);
}
private readonly Bindings _bindings;
diff --git a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs
index 9945ee6cc3..6f8185b75a 100644
--- a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs
@@ -188,7 +188,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(DropSlotsTransform).Assembly.FullName);
}
private const string RegistrationName = "DropSlots";
diff --git a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs
index c141bb66c1..11adab9ab6 100644
--- a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs
@@ -249,7 +249,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(GenerateNumberTransform).Assembly.FullName);
}
private readonly Bindings _bindings;
diff --git a/src/Microsoft.ML.Data/Transforms/HashTransform.cs b/src/Microsoft.ML.Data/Transforms/HashTransform.cs
index a589121be6..6ad7f68cbf 100644
--- a/src/Microsoft.ML.Data/Transforms/HashTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/HashTransform.cs
@@ -196,7 +196,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Invert hash key values, hash fix
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010002,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(HashTransformer).Assembly.FullName);
}
private readonly ColumnInfo[] _columns;
diff --git a/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs b/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs
index db6ff54298..4c90bc1d9d 100644
--- a/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs
+++ b/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs
@@ -329,7 +329,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(TextModelHelper).Assembly.FullName);
}
private static void Load(IChannel ch, ModelLoadContext ctx, CodecFactory factory, ref VBuffer> values)
diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs
index de8463a804..d9e75fb178 100644
--- a/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs
@@ -76,7 +76,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(KeyToValueTransform).Assembly.FullName);
}
///
diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs
index bda8729d4a..fd177d6a14 100644
--- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs
@@ -143,7 +143,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Get rid of writing float size in model context
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(KeyToVectorTransform).Assembly.FullName);
}
public override void Save(ModelSaveContext ctx)
diff --git a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs
index 8817833f40..b5943b14fa 100644
--- a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs
@@ -58,7 +58,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial.
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(LabelConvertTransform).Assembly.FullName);
}
private const string RegistrationName = "LabelConvert";
diff --git a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs
index 81882ac749..5844d61d31 100644
--- a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs
@@ -39,7 +39,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(LabelIndicatorTransform).Assembly.FullName);
}
public sealed class Column : OneToOneColumn
diff --git a/src/Microsoft.ML.Data/Transforms/NAFilter.cs b/src/Microsoft.ML.Data/Transforms/NAFilter.cs
index c8515291f3..ed738f3e63 100644
--- a/src/Microsoft.ML.Data/Transforms/NAFilter.cs
+++ b/src/Microsoft.ML.Data/Transforms/NAFilter.cs
@@ -70,7 +70,8 @@ private static VersionInfo GetVersionInfo()
loaderSignature: LoaderSignature,
// This is an older name and can be removed once we don't care about old code
// being able to load this.
- loaderSignatureAlt: "MissingFeatureFilter");
+ loaderSignatureAlt: "MissingFeatureFilter",
+ loaderAssemblyName: typeof(NAFilter).Assembly.FullName);
}
private readonly ColInfo[] _infos;
diff --git a/src/Microsoft.ML.Data/Transforms/NopTransform.cs b/src/Microsoft.ML.Data/Transforms/NopTransform.cs
index 0a71a3ec07..a367073f2f 100644
--- a/src/Microsoft.ML.Data/Transforms/NopTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/NopTransform.cs
@@ -55,7 +55,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(NopTransform).Assembly.FullName);
}
internal static string RegistrationName = "NopTransform";
diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
index b515b2293c..a91c4645eb 100644
--- a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
+++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
@@ -615,7 +615,8 @@ public static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(CdfColumnFunction).Assembly.FullName);
}
}
@@ -677,7 +678,8 @@ public static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(BinColumnFunction).Assembly.FullName);
}
}
@@ -1139,7 +1141,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010003, // Scales multiply instead of divide
verReadableCur: 0x00010003,
verWeCanReadBack: 0x00010003,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(AffineNormSerializationUtils).Assembly.FullName);
}
}
@@ -1154,7 +1157,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(BinNormSerializationUtils).Assembly.FullName);
}
}
diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs
index 7cf6bd3d31..4c3d9fd51d 100644
--- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs
+++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs
@@ -216,7 +216,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(NormalizerTransformer).Assembly.FullName);
}
private class ColumnInfo
diff --git a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs
index 142779dee2..d20446f7f5 100644
--- a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs
+++ b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs
@@ -64,7 +64,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(RangeFilter).Assembly.FullName);
}
private const string RegistrationName = "RangeFilter";
diff --git a/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs b/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs
index 3940bbe979..4043e3c42b 100644
--- a/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs
@@ -71,7 +71,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Force shuffle source saving
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010002,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(ShuffleTransform).Assembly.FullName);
}
private const string RegistrationName = "Shuffle";
diff --git a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs
index 2adb17258e..440d68ae84 100644
--- a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs
+++ b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs
@@ -76,7 +76,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(SkipTakeFilter).Assembly.FullName);
}
private readonly long _skip;
diff --git a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs
index 55db806941..c3084a9fa7 100644
--- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs
@@ -193,7 +193,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010003, // Generalize to multiple types beyond text
verReadableCur: 0x00010003,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(TermTransform).Assembly.FullName);
}
private const uint VerNonTextTypesSupported = 0x00010003;
@@ -224,7 +225,8 @@ private static VersionInfo GetTermManagerVersionInfo()
verWrittenCur: 0x00010002, // Generalize to multiple types beyond text
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
- loaderSignature: TermManagerLoaderSignature);
+ loaderSignature: TermManagerLoaderSignature,
+ loaderAssemblyName: typeof(TermTransform).Assembly.FullName);
}
private readonly TermMap[] _unboundMaps;
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs
index 45cd764d13..de1e5ef505 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs
@@ -25,7 +25,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(Average).Assembly.FullName);
}
public Average(IHostEnvironment env)
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs
index de8d950de4..95bc0cc991 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs
@@ -30,7 +30,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(Median).Assembly.FullName);
}
public Median(IHostEnvironment env)
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs
index c147f932f3..fef6fa087e 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs
@@ -28,7 +28,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(MultiAverage).Assembly.FullName);
}
[TlcModule.Component(Name = LoadName, FriendlyName = Average.UserName)]
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs
index c3e6869d69..86312393de 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs
@@ -31,7 +31,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(MultiMedian).Assembly.FullName);
}
[TlcModule.Component(Name = LoadName, FriendlyName = Median.UserName)]
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
index 67af075f7d..fd6af4dc53 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
@@ -35,7 +35,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(MultiStacking).Assembly.FullName);
}
[TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)]
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs
index ee55b94c77..0c68f70287 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs
@@ -29,7 +29,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(MultiVoting).Assembly.FullName);
}
private sealed class Arguments : ArgumentsBase
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs
index 9bda1d151a..a2f52b1451 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs
@@ -35,7 +35,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(MultiWeightedAverage).Assembly.FullName);
}
[TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
index fa68546eb8..f1503ed23d 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs
@@ -33,7 +33,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(RegressionStacking).Assembly.FullName);
}
[TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)]
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
index 63adbd7c56..9a569cb2af 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
@@ -31,7 +31,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(Stacking).Assembly.FullName);
}
[TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs
index 932f99d93a..d352439d55 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs
@@ -27,7 +27,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(Voting).Assembly.FullName);
}
public Voting(IHostEnvironment env)
diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs
index 8b16ffd0a2..e11f0c9bb9 100644
--- a/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs
+++ b/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs
@@ -32,7 +32,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(WeightedAverage).Assembly.FullName);
}
[TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs
index e4d6cb0d9e..a13903f1a3 100644
--- a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs
+++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs
@@ -377,7 +377,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Save predictor models in a subdirectory
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(SchemaBindablePipelineEnsembleBase).Assembly.FullName);
}
public const string UserName = "Pipeline Ensemble";
public const string LoaderSignature = "PipelineEnsemble";
diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs
index 547800c152..a15563a289 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs
@@ -37,7 +37,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010003, // Don't serialize the "IsAveraged" property of the metrics
verReadableCur: 0x00010003,
verWeCanReadBack: 0x00010002,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(EnsembleDistributionPredictor).Assembly.FullName);
}
private readonly Single[] _averagedWeights;
diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs
index 08c8f0dd8d..1f2ed87bf6 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictor.cs
@@ -35,7 +35,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010003, // Don't serialize the "IsAveraged" property of the metrics
verReadableCur: 0x00010003,
verWeCanReadBack: 0x00010002,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(EnsemblePredictor).Assembly.FullName);
}
private readonly IValueMapper[] _mappers;
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs
index 558d0afd6e..4dfaf3983a 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMultiClassPredictor.cs
@@ -32,7 +32,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010003, // Don't serialize the "IsAveraged" property of the metrics
verReadableCur: 0x00010003,
verWeCanReadBack: 0x00010002,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(EnsembleMultiClassPredictor).Assembly.FullName);
}
private readonly ColumnType _inputType;
diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs
index 0d9d5bc192..ca103f3c25 100644
--- a/src/Microsoft.ML.FastTree/FastTree.cs
+++ b/src/Microsoft.ML.FastTree/FastTree.cs
@@ -819,7 +819,7 @@ protected virtual void PrintPrologInfo(IChannel ch)
{
Contracts.AssertValue(ch);
ch.Trace("Host = {0}", Environment.MachineName);
- ch.Trace("CommandLine = {0}", CmdParser.GetSettings(ch, Args, new TArgs()));
+ ch.Trace("CommandLine = {0}", CmdParser.GetSettings(Host, Args, new TArgs()));
ch.Trace("GCSettings.IsServerGC = {0}", System.Runtime.GCSettings.IsServerGC);
ch.Trace("{0}", Args);
}
diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs
index 97cc1c43f9..8537555ec9 100644
--- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs
@@ -58,7 +58,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010005, //Categorical splits.
verReadableCur: 0x00010005,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(FastTreeBinaryPredictor).Assembly.FullName);
}
protected override uint VerNumFeaturesSerialized => 0x00010002;
diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs
index 90331c47b4..0c8419846a 100644
--- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs
@@ -1100,7 +1100,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010005, // Categorical splits.
verReadableCur: 0x00010004,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(FastTreeRankingPredictor).Assembly.FullName);
}
protected override uint VerNumFeaturesSerialized => 0x00010002;
diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs
index 717824b6f5..fde54781d9 100644
--- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs
@@ -447,7 +447,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010005, // Categorical splits.
verReadableCur: 0x00010004,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(FastTreeRegressionPredictor).Assembly.FullName);
}
protected override uint VerNumFeaturesSerialized => 0x00010002;
diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
index eedbfe7389..c028ef963a 100644
--- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
@@ -442,7 +442,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010003, // Categorical splits.
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(FastTreeTweediePredictor).Assembly.FullName);
}
protected override uint VerNumFeaturesSerialized => 0x00010001;
diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs
index 1448333764..184cbdc565 100644
--- a/src/Microsoft.ML.FastTree/GamClassification.cs
+++ b/src/Microsoft.ML.FastTree/GamClassification.cs
@@ -132,7 +132,8 @@ public static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(BinaryClassGamPredictor).Assembly.FullName);
}
public static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx)
diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs
index 119e8c2a85..e55a3ee008 100644
--- a/src/Microsoft.ML.FastTree/GamRegression.cs
+++ b/src/Microsoft.ML.FastTree/GamRegression.cs
@@ -87,7 +87,8 @@ public static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(RegressionGamPredictor).Assembly.FullName);
}
public static RegressionGamPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs
index 0deceb1aff..43bbb60e15 100644
--- a/src/Microsoft.ML.FastTree/GamTrainer.cs
+++ b/src/Microsoft.ML.FastTree/GamTrainer.cs
@@ -956,11 +956,11 @@ public sealed class Arguments : DataCommand.ArgumentsBase
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to open the GAM visualization page URL", ShortName = "o", SortOrder = 3)]
public bool Open = true;
- internal Arguments SetServerIfNeeded(IExceptionContext ectx)
+ internal Arguments SetServerIfNeeded(IHostEnvironment env)
{
// We assume that if someone invoked this, they really did mean to start the web server.
- if (ectx != null && Server == null)
- Server = ServerChannel.CreateDefaultServerFactoryOrNull(ectx);
+ if (env != null && Server == null)
+ Server = ServerChannel.CreateDefaultServerFactoryOrNull(env);
return this;
}
}
diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs
index b5ba1f43e7..fe79ab5d46 100644
--- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs
+++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs
@@ -63,7 +63,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010006, // Categorical splits.
verReadableCur: 0x00010005,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(FastForestClassificationPredictor).Assembly.FullName);
}
protected override uint VerNumFeaturesSerialized => 0x00010003;
diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs
index 1261cd0e37..68294596e4 100644
--- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs
+++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs
@@ -49,7 +49,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010006, // Categorical splits.
verReadableCur: 0x00010005,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(FastForestRegressionPredictor).Assembly.FullName);
}
protected override uint VerNumFeaturesSerialized => 0x00010003;
diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
index fcd2550147..6faeba3cd8 100644
--- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
+++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
@@ -418,7 +418,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Add _defaultValueForMissing
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(TreeEnsembleFeaturizerBindableMapper).Assembly.FullName);
}
private readonly IHost _host;
diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
index cb7472633f..879a32b95f 100644
--- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
+++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
@@ -500,7 +500,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(OlsLinearRegressionPredictor).Assembly.FullName);
}
// The following will be null iff RSquaredAdjusted is NaN.
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs
index 5744fe384d..a137dcfd6a 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs
@@ -74,7 +74,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(ImageGrayscaleTransform).Assembly.FullName);
}
private const string RegistrationName = "ImageGrayscale";
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs
index 067cd30747..90e3fd6b3e 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs
@@ -141,7 +141,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010002,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(ImageLoaderTransform).Assembly.FullName);
}
protected override IRowMapper MakeRowMapper(ISchema schema)
diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs
index 1e563fbc2c..c2cd0f1a83 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs
@@ -297,7 +297,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010002,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(ImagePixelExtractorTransform).Assembly.FullName);
}
private const string RegistrationName = "ImagePixelExtractor";
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs
index 1755430cf3..2de1354205 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs
@@ -154,7 +154,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010003, // No more sizeof(float)
verReadableCur: 0x00010003,
verWeCanReadBack: 0x00010003,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(ImageResizerTransform).Assembly.FullName);
}
private const string RegistrationName = "ImageScaler";
diff --git a/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs b/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs
index 2446b7a7b6..a6f69204c4 100644
--- a/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs
@@ -236,7 +236,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Swith from OpenCV to Bitmap
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010002,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(VectorToImageTransform).Assembly.FullName);
}
private const string RegistrationName = "VectorToImageConverter";
diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPredictor.cs b/src/Microsoft.ML.KMeansClustering/KMeansPredictor.cs
index e3249da34d..e1b54f472a 100644
--- a/src/Microsoft.ML.KMeansClustering/KMeansPredictor.cs
+++ b/src/Microsoft.ML.KMeansClustering/KMeansPredictor.cs
@@ -38,7 +38,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Allow sparse centroids
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(KMeansPredictor).Assembly.FullName);
}
public override PredictionKind PredictionKind => PredictionKind.Clustering;
diff --git a/src/Microsoft.ML.Legacy/Microsoft.ML.Legacy.csproj b/src/Microsoft.ML.Legacy/Microsoft.ML.Legacy.csproj
index bd3d72ab68..a2eeeb046c 100644
--- a/src/Microsoft.ML.Legacy/Microsoft.ML.Legacy.csproj
+++ b/src/Microsoft.ML.Legacy/Microsoft.ML.Legacy.csproj
@@ -6,6 +6,10 @@
CORECLR
+
+
+
+
diff --git a/src/Microsoft.ML.Legacy/PredictionModel.cs b/src/Microsoft.ML.Legacy/PredictionModel.cs
index 45c4738fe2..29f1bf35e9 100644
--- a/src/Microsoft.ML.Legacy/PredictionModel.cs
+++ b/src/Microsoft.ML.Legacy/PredictionModel.cs
@@ -127,6 +127,8 @@ public static Task> ReadAsync(
using (var environment = new ConsoleEnvironment())
{
+ AssemblyLoadingUtils.RegisterCurrentLoadedAssemblies(environment);
+
BatchPredictionEngine predictor =
environment.CreateBatchPredictionEngine(stream);
diff --git a/src/Microsoft.ML.Legacy/Runtime/Experiment/Experiment.cs b/src/Microsoft.ML.Legacy/Runtime/Experiment/Experiment.cs
index 108befb74b..efb9cb33d0 100644
--- a/src/Microsoft.ML.Legacy/Runtime/Experiment/Experiment.cs
+++ b/src/Microsoft.ML.Legacy/Runtime/Experiment/Experiment.cs
@@ -37,6 +37,8 @@ private sealed class SerializationHelper
public Experiment(Runtime.IHostEnvironment env)
{
_env = env;
+ AssemblyLoadingUtils.RegisterCurrentLoadedAssemblies(_env);
+
_catalog = ModuleCatalog.CreateInstance(_env);
_jsonNodes = new List();
_serializer = new JsonSerializer();
diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs
index f788b4feab..6fe77564b7 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs
@@ -40,7 +40,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010005, // Categorical splits.
verReadableCur: 0x00010004,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(LightGbmBinaryPredictor).Assembly.FullName);
}
protected override uint VerNumFeaturesSerialized => 0x00010002;
diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs
index 3fe4628182..7373fbbac9 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs
@@ -38,7 +38,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010005, // Categorical splits.
verReadableCur: 0x00010004,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(LightGbmRankingPredictor).Assembly.FullName);
}
protected override uint VerNumFeaturesSerialized => 0x00010002;
diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs
index 0011a8d8e6..08331cb45f 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs
@@ -38,7 +38,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010005, // Categorical splits.
verReadableCur: 0x00010004,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(LightGbmRegressionPredictor).Assembly.FullName);
}
protected override uint VerNumFeaturesSerialized => 0x00010002;
diff --git a/src/Microsoft.ML.Maml/HelpCommand.cs b/src/Microsoft.ML.Maml/HelpCommand.cs
index 401cb7dd10..6c6d800ac7 100644
--- a/src/Microsoft.ML.Maml/HelpCommand.cs
+++ b/src/Microsoft.ML.Maml/HelpCommand.cs
@@ -100,7 +100,7 @@ public void Run()
public void Run(int? columns)
{
- ComponentCatalog.CacheClassesExtra(_extraAssemblies);
+ AssemblyLoadingUtils.LoadAndRegister(_env, _extraAssemblies);
using (var ch = _env.Start("Help"))
using (var sw = new StringWriter(CultureInfo.InvariantCulture))
@@ -137,7 +137,7 @@ private void ShowHelp(IndentingTextWriter writer, int? columns = null)
// Note that we don't check IsHidden here. The current policy is when IsHidden is true, we don't
// show the item in "list all" functionality, but will still show help when explicitly requested.
- var infos = ComponentCatalog.FindLoadableClasses(name)
+ var infos = _env.ComponentCatalog.FindLoadableClasses(name)
.OrderBy(x => ComponentCatalog.SignatureToString(x.SignatureTypes[0]).ToLowerInvariant());
var kinds = new StringBuilder();
var components = new List();
@@ -188,7 +188,7 @@ private void ShowAllHelp(IndentingTextWriter writer, int? columns = null)
{
string sig = _kind?.ToLowerInvariant();
- var infos = ComponentCatalog.GetAllClasses()
+ var infos = _env.ComponentCatalog.GetAllClasses()
.OrderBy(info => info.LoadNames[0].ToLowerInvariant())
.ThenBy(info => ComponentCatalog.SignatureToString(info.SignatureTypes[0]).ToLowerInvariant());
var components = new List();
@@ -256,7 +256,7 @@ private void ShowComponents(IndentingTextWriter writer)
else
{
kind = _kind.ToLowerInvariant();
- var sigs = ComponentCatalog.GetAllSignatureTypes();
+ var sigs = _env.ComponentCatalog.GetAllSignatureTypes();
typeSig = sigs.FirstOrDefault(t => ComponentCatalog.SignatureToString(t).ToLowerInvariant() == kind);
if (typeSig == null)
{
@@ -272,7 +272,7 @@ private void ShowComponents(IndentingTextWriter writer)
writer.WriteLine("Available components for kind '{0}':", ComponentCatalog.SignatureToString(typeSig));
}
- var infos = ComponentCatalog.GetAllDerivedClasses(typeRes, typeSig)
+ var infos = _env.ComponentCatalog.GetAllDerivedClasses(typeRes, typeSig)
.Where(x => !x.IsHidden)
.OrderBy(x => x.LoadNames[0].ToLowerInvariant());
using (writer.Nest())
@@ -322,7 +322,7 @@ private void ShowAliases(IndentingTextWriter writer, IReadOnlyList names
private void ListKinds(IndentingTextWriter writer)
{
- var sigs = ComponentCatalog.GetAllSignatureTypes()
+ var sigs = _env.ComponentCatalog.GetAllSignatureTypes()
.Select(ComponentCatalog.SignatureToString)
.OrderBy(x => x);
diff --git a/src/Microsoft.ML.Maml/MAML.cs b/src/Microsoft.ML.Maml/MAML.cs
index cac407c21a..c88ee6112c 100644
--- a/src/Microsoft.ML.Maml/MAML.cs
+++ b/src/Microsoft.ML.Maml/MAML.cs
@@ -55,7 +55,10 @@ public static int Main()
private static int MainWithProgress(string args)
{
+ string currentDirectory = Path.GetDirectoryName(typeof(Maml).Module.FullyQualifiedName);
+
using (var env = CreateEnvironment())
+ using (AssemblyLoadingUtils.CreateAssemblyRegistrar(env, currentDirectory))
using (var progressCancel = new CancellationTokenSource())
{
var progressTrackerTask = Task.Run(() => TrackProgress(env, progressCancel.Token));
diff --git a/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj b/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj
index 88b1e9e55b..219b6bf0b8 100644
--- a/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj
+++ b/src/Microsoft.ML.Maml/Microsoft.ML.Maml.csproj
@@ -3,10 +3,14 @@
true
CORECLR
- Microsoft.ML
- netstandard2.0
+ Microsoft.ML
+ netstandard2.0
+
+
+
+
diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs
index 2654af5d13..b784f1f6ee 100644
--- a/src/Microsoft.ML.PCA/PcaTrainer.cs
+++ b/src/Microsoft.ML.PCA/PcaTrainer.cs
@@ -307,7 +307,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(PcaPredictor).Assembly.FullName);
}
private readonly int _dimension;
diff --git a/src/Microsoft.ML.PCA/PcaTransform.cs b/src/Microsoft.ML.PCA/PcaTransform.cs
index abb4b9b821..dd1596f330 100644
--- a/src/Microsoft.ML.PCA/PcaTransform.cs
+++ b/src/Microsoft.ML.PCA/PcaTransform.cs
@@ -184,7 +184,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(PcaTransform).Assembly.FullName);
}
// These are parallel to Infos.
diff --git a/src/Microsoft.ML.Parquet/ParquetLoader.cs b/src/Microsoft.ML.Parquet/ParquetLoader.cs
index e131949138..667600c44e 100644
--- a/src/Microsoft.ML.Parquet/ParquetLoader.cs
+++ b/src/Microsoft.ML.Parquet/ParquetLoader.cs
@@ -109,7 +109,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Add Schema to Model Context
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(ParquetLoader).Assembly.FullName);
}
public ParquetLoader(IHostEnvironment env, Arguments args, IMultiStreamSource files)
diff --git a/src/Microsoft.ML.PipelineInference/ExperimentsGenerator.cs b/src/Microsoft.ML.PipelineInference/ExperimentsGenerator.cs
index 053b86a40b..ee0fa11bde 100644
--- a/src/Microsoft.ML.PipelineInference/ExperimentsGenerator.cs
+++ b/src/Microsoft.ML.PipelineInference/ExperimentsGenerator.cs
@@ -110,7 +110,7 @@ public static List GenerateCandidates(IHostEnvironment env, string dataFi
//get all the trainers for this task, and generate the initial set of candidates.
// Exclude the hidden learners, and the metalinear learners.
- var trainers = ComponentCatalog.GetAllDerivedClasses(typeof(ITrainer), predictorType).Where(cls => !cls.IsHidden);
+ var trainers = env.ComponentCatalog.GetAllDerivedClasses(typeof(ITrainer), predictorType).Where(cls => !cls.IsHidden);
if (!string.IsNullOrEmpty(loaderSettings))
{
diff --git a/src/Microsoft.ML.PipelineInference/RecipeInference.cs b/src/Microsoft.ML.PipelineInference/RecipeInference.cs
index 67f3a7f02d..4363c8a4d7 100644
--- a/src/Microsoft.ML.PipelineInference/RecipeInference.cs
+++ b/src/Microsoft.ML.PipelineInference/RecipeInference.cs
@@ -128,13 +128,13 @@ public InferenceResult(SuggestedRecipe[] suggestedRecipes)
}
}
- private static IEnumerable GetRecipes()
+ private static IEnumerable GetRecipes(IHostEnvironment env)
{
yield return new DefaultRecipe();
- yield return new BalancedTextClassificationRecipe();
- yield return new AccuracyFocusedRecipe();
- yield return new ExplorationComboRecipe();
- yield return new TreeLeafRecipe();
+ yield return new BalancedTextClassificationRecipe(env);
+ yield return new AccuracyFocusedRecipe(env);
+ yield return new ExplorationComboRecipe(env);
+ yield return new TreeLeafRecipe(env);
}
public abstract class Recipe
@@ -210,6 +210,14 @@ protected override IEnumerable ApplyCore(Type predictorType,
public abstract class MultiClassRecipies : Recipe
{
+ protected IHostEnvironment Host { get; }
+
+ protected MultiClassRecipies(IHostEnvironment host)
+ {
+ Contracts.CheckValue(host, nameof(host));
+ Host = host;
+ }
+
public override List AllowedTransforms() => base.AllowedTransforms().Where(
expert =>
expert != typeof(TransformInference.Experts.Text) &&
@@ -227,6 +235,11 @@ public override List AllowedTransforms() => base.AllowedTransforms().Where
public sealed class BalancedTextClassificationRecipe : MultiClassRecipies
{
+ public BalancedTextClassificationRecipe(IHostEnvironment host)
+ : base(host)
+ {
+ }
+
public override List QualifierTransforms()
=> new List { typeof(TransformInference.Experts.TextBiGramTriGram) };
@@ -236,12 +249,12 @@ protected override IEnumerable ApplyCore(Type predictorType,
SuggestedRecipe.SuggestedLearner learner = new SuggestedRecipe.SuggestedLearner();
if (predictorType == typeof(SignatureMultiClassClassifierTrainer))
{
- learner.LoadableClassInfo = ComponentCatalog.GetLoadableClassInfo("OVA");
+ learner.LoadableClassInfo = Host.ComponentCatalog.GetLoadableClassInfo("OVA");
learner.Settings = "p=AveragedPerceptron{iter=10}";
}
else
{
- learner.LoadableClassInfo = ComponentCatalog.GetLoadableClassInfo(Learners.AveragedPerceptronTrainer.LoadNameValue);
+ learner.LoadableClassInfo = Host.ComponentCatalog.GetLoadableClassInfo(Learners.AveragedPerceptronTrainer.LoadNameValue);
learner.Settings = "iter=10";
var epInput = new Legacy.Trainers.AveragedPerceptronBinaryClassifier
{
@@ -259,6 +272,11 @@ protected override IEnumerable ApplyCore(Type predictorType,
public sealed class AccuracyFocusedRecipe : MultiClassRecipies
{
+ public AccuracyFocusedRecipe(IHostEnvironment host)
+ : base(host)
+ {
+ }
+
public override List QualifierTransforms()
=> new List { typeof(TransformInference.Experts.TextUniGramTriGram) };
@@ -268,13 +286,13 @@ protected override IEnumerable ApplyCore(Type predictorType,
SuggestedRecipe.SuggestedLearner learner = new SuggestedRecipe.SuggestedLearner();
if (predictorType == typeof(SignatureMultiClassClassifierTrainer))
{
- learner.LoadableClassInfo = ComponentCatalog.GetLoadableClassInfo("OVA");
+ learner.LoadableClassInfo = Host.ComponentCatalog.GetLoadableClassInfo("OVA");
learner.Settings = "p=FastTreeBinaryClassification";
}
else
{
learner.LoadableClassInfo =
- ComponentCatalog.GetLoadableClassInfo(FastTreeBinaryClassificationTrainer.LoadNameValue);
+ Host.ComponentCatalog.GetLoadableClassInfo(FastTreeBinaryClassificationTrainer.LoadNameValue);
learner.Settings = "";
var epInput = new Legacy.Trainers.FastTreeBinaryClassifier();
learner.PipelineNode = new TrainerPipelineNode(epInput);
@@ -288,6 +306,11 @@ protected override IEnumerable ApplyCore(Type predictorType,
public sealed class ExplorationComboRecipe : MultiClassRecipies
{
+ public ExplorationComboRecipe(IHostEnvironment host)
+ : base(host)
+ {
+ }
+
public override List QualifierTransforms()
=> new List { typeof(TransformInference.Experts.SdcaTransform) };
@@ -298,12 +321,12 @@ protected override IEnumerable ApplyCore(Type predictorType,
if (predictorType == typeof(SignatureMultiClassClassifierTrainer))
{
learner.LoadableClassInfo =
- ComponentCatalog.GetLoadableClassInfo(Learners.SdcaMultiClassTrainer.LoadNameValue);
+ Host.ComponentCatalog.GetLoadableClassInfo(Learners.SdcaMultiClassTrainer.LoadNameValue);
}
else
{
learner.LoadableClassInfo =
- ComponentCatalog.GetLoadableClassInfo(Learners.LinearClassificationTrainer.LoadNameValue);
+ Host.ComponentCatalog.GetLoadableClassInfo(Learners.LinearClassificationTrainer.LoadNameValue);
var epInput = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier();
learner.PipelineNode = new TrainerPipelineNode(epInput);
}
@@ -317,6 +340,11 @@ protected override IEnumerable ApplyCore(Type predictorType,
public sealed class TreeLeafRecipe : MultiClassRecipies
{
+ public TreeLeafRecipe(IHostEnvironment host)
+ : base(host)
+ {
+ }
+
public override List QualifierTransforms()
=> new List { typeof(TransformInference.Experts.NaiveBayesTransform) };
@@ -325,7 +353,7 @@ protected override IEnumerable ApplyCore(Type predictorType,
{
SuggestedRecipe.SuggestedLearner learner = new SuggestedRecipe.SuggestedLearner();
learner.LoadableClassInfo =
- ComponentCatalog.GetLoadableClassInfo(Learners.MultiClassNaiveBayesTrainer.LoadName);
+ Host.ComponentCatalog.GetLoadableClassInfo(Learners.MultiClassNaiveBayesTrainer.LoadName);
learner.Settings = "";
var epInput = new Legacy.Trainers.NaiveBayesClassifier();
learner.PipelineNode = new TrainerPipelineNode(epInput);
@@ -403,7 +431,7 @@ public static SuggestedRecipe[] InferRecipesFromData(IHostEnvironment env, strin
AllowQuoting = splitResult.AllowQuote
};
- settingsString = CommandLine.CmdParser.GetSettings(ch, finalLoaderArgs, new TextLoader.Arguments());
+ settingsString = CommandLine.CmdParser.GetSettings(h, finalLoaderArgs, new TextLoader.Arguments());
ch.Info($"Loader options: {settingsString}");
ch.Info("Inferring recipes");
@@ -440,7 +468,7 @@ public static InferenceResult InferRecipes(IHostEnvironment env, TransformInfere
using (var ch = h.Start("InferRecipes"))
{
var list = new List();
- foreach (var recipe in GetRecipes())
+ foreach (var recipe in GetRecipes(h))
list.AddRange(recipe.Apply(transformInferenceResult, predictorType, ch));
if (list.Count == 0)
diff --git a/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj b/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj
index b420da0eb0..e5610126df 100644
--- a/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj
+++ b/src/Microsoft.ML.ResultProcessor/Microsoft.ML.ResultProcessor.csproj
@@ -1,4 +1,4 @@
-
+
netstandard2.0
@@ -7,6 +7,10 @@
true
+
+
+
+
diff --git a/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs b/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs
index eb42e452bd..b142cea0ba 100644
--- a/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs
+++ b/src/Microsoft.ML.ResultProcessor/ResultProcessor.cs
@@ -12,7 +12,6 @@
using Microsoft.ML.Runtime.Command;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.Tools;
@@ -152,9 +151,9 @@ public void GetDefaultSettingValues(IHostEnvironment env, string predictorName,
///
private Dictionary GetDefaultSettings(IHostEnvironment env, string predictorName, string[] extraAssemblies = null)
{
- ComponentCatalog.CacheClassesExtra(extraAssemblies);
+ AssemblyLoadingUtils.LoadAndRegister(env, extraAssemblies);
- var cls = ComponentCatalog.GetLoadableClassInfo(predictorName);
+ var cls = env.ComponentCatalog.GetLoadableClassInfo(predictorName);
if (cls == null)
{
Console.Error.WriteLine("Can't load trainer '{0}'", predictorName);
@@ -521,7 +520,7 @@ private static bool ValidateMamlOutput(string filename, string[] rawLines, out L
ICommandLineComponentFactory commandLineTrainer = trainer as ICommandLineComponentFactory;
Contracts.AssertValue(commandLineTrainer, "ResultProcessor can only work with ICommandLineComponentFactory.");
- trainerClass = ComponentCatalog.GetLoadableClassInfo(commandLineTrainer.Name);
+ trainerClass = env.ComponentCatalog.GetLoadableClassInfo(commandLineTrainer.Name);
trainerArgs = trainerClass.CreateArguments();
Dictionary predictorSettings;
if (trainerArgs == null)
@@ -678,7 +677,12 @@ public static bool ParseCommandArguments(IHostEnvironment env, string commandlin
return false;
}
- commandClass = ComponentCatalog.GetLoadableClassInfo(kind);
+ commandClass = env.ComponentCatalog.GetLoadableClassInfo(kind);
+ if (commandClass == null)
+ {
+ commandArgs = null;
+ return false;
+ }
commandArgs = commandClass.CreateArguments();
CmdParser.ParseArguments(env, settings, commandArgs);
return true;
@@ -1147,10 +1151,15 @@ private static object Load(Stream stream)
}
public static int Main(string[] args)
+ {
+ return Main(new ConsoleEnvironment(42), args);
+ }
+
+ public static int Main(IHostEnvironment env, string[] args)
{
try
{
- Run(args);
+ Run(env, args);
return 0;
}
catch (Exception e)
@@ -1170,10 +1179,9 @@ public static int Main(string[] args)
}
}
- protected static void Run(string[] args)
+ protected static void Run(IHostEnvironment env, string[] args)
{
ResultProcessorArguments cmd = new ResultProcessorArguments();
- ConsoleEnvironment env = new ConsoleEnvironment(42);
List predictorResultsList = new List();
PredictionUtil.ParseArguments(env, cmd, PredictionUtil.CombineSettings(args));
@@ -1185,8 +1193,8 @@ protected static void Run(string[] args)
if (cmd.IncludePerFoldResults)
cmd.PerFoldResultSeparator = "" + PredictionUtil.SepCharFromString(cmd.PerFoldResultSeparator);
- foreach (var dll in cmd.ExtraAssemblies)
- ComponentCatalog.LoadAssembly(dll);
+
+ AssemblyLoadingUtils.LoadAndRegister(env, cmd.ExtraAssemblies);
if (cmd.Metrics.Length == 0)
cmd.Metrics = null;
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs
index f5da9327c1..8e95682b51 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs
@@ -40,7 +40,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(FieldAwareFactorizationMachinePredictor).Assembly.FullName);
}
internal FieldAwareFactorizationMachinePredictor(IHostEnvironment env, bool norm, int fieldCount, int featureCount, int latentDim,
@@ -353,7 +354,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(FieldAwareFactorizationMachinePredictionTransformer).Assembly.FullName);
}
private static FieldAwareFactorizationMachinePredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs
index eb0644d2da..ed41a14174 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs
@@ -409,7 +409,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00020002, // Added model statistics
verReadableCur: 0x00020001,
verWeCanReadBack: 0x00020001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(LinearBinaryPredictor).Assembly.FullName);
}
///
@@ -597,7 +598,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00020001, // Fixed sparse serialization
verReadableCur: 0x00020001,
verWeCanReadBack: 0x00020001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(LinearRegressionPredictor).Assembly.FullName);
}
///
@@ -679,7 +681,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00020001, // Fixed sparse serialization
verReadableCur: 0x00020001,
verWeCanReadBack: 0x00020001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(PoissonRegressionPredictor).Assembly.FullName);
}
internal PoissonRegressionPredictor(IHostEnvironment env, ref VBuffer weights, Float bias)
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
index b8986cce77..2a90e6cb2a 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
@@ -304,7 +304,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010003, // Added model stats
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(MulticlassLogisticRegressionPredictor).Assembly.FullName);
}
private const string ModelStatsSubModelFilename = "ModelStats";
diff --git a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs
index e1594041f8..cc28c1caf1 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs
@@ -59,7 +59,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(LinearModelStatistics).Assembly.FullName);
}
private readonly IHostEnvironment _env;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
index 8c96ee1e0b..0817db47a5 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
@@ -151,7 +151,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(MultiClassNaiveBayesPredictor).Assembly.FullName);
}
private readonly int[] _labelHistogram;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
index b918c5acce..6229e9f70f 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
@@ -214,7 +214,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(OvaPredictor).Assembly.FullName);
}
private const string SubPredictorFmt = "SubPredictor_{0:000}";
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs
index 6a8736bad1..da9bfc99b6 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs
@@ -225,7 +225,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(PkpdPredictor).Assembly.FullName);
}
private const string SubPredictorFmt = "SubPredictor_{0:000}";
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs
index e823d1fcd1..804943ac0d 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs
@@ -120,7 +120,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(RandomPredictor).Assembly.FullName);
}
// Keep all the serializable state here.
@@ -358,7 +359,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(PriorPredictor).Assembly.FullName);
}
private readonly float _prob;
diff --git a/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs b/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs
index ec80ce89b1..fded15ddaf 100644
--- a/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs
+++ b/src/Microsoft.ML.Sweeper/SweepResultEvaluator.cs
@@ -42,7 +42,7 @@ private string FindMetric(string userMetric, out bool maximizing)
{
StringBuilder sb = new StringBuilder();
- var evaluators = ComponentCatalog.GetAllDerivedClasses(typeof(IMamlEvaluator), typeof(SignatureMamlEvaluator));
+ var evaluators = _host.ComponentCatalog.GetAllDerivedClasses(typeof(IMamlEvaluator), typeof(SignatureMamlEvaluator));
foreach (var evalInfo in evaluators)
{
var args = evalInfo.CreateArguments();
diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
index 69532de8fc..0afb03b7a2 100644
--- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
@@ -76,7 +76,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Upgraded when change for multiple outputs was implemented.
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(TensorFlowTransform).Assembly.FullName);
}
///
diff --git a/src/Microsoft.ML.Transforms/BootstrapSampleTransform.cs b/src/Microsoft.ML.Transforms/BootstrapSampleTransform.cs
index 91106bd445..dc4bde762c 100644
--- a/src/Microsoft.ML.Transforms/BootstrapSampleTransform.cs
+++ b/src/Microsoft.ML.Transforms/BootstrapSampleTransform.cs
@@ -59,7 +59,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(BootstrapSampleTransform).Assembly.FullName);
}
internal const string RegistrationName = "BootstrapSample";
diff --git a/src/Microsoft.ML.Transforms/CompositeTransform.cs b/src/Microsoft.ML.Transforms/CompositeTransform.cs
index 4dfeb3b312..48cdfe345f 100644
--- a/src/Microsoft.ML.Transforms/CompositeTransform.cs
+++ b/src/Microsoft.ML.Transforms/CompositeTransform.cs
@@ -30,7 +30,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(CompositeTransform).Assembly.FullName);
}
public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
diff --git a/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs b/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs
index fbb544da50..ae9a721bd9 100644
--- a/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs
+++ b/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs
@@ -55,7 +55,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(GaussianFourierSampler).Assembly.FullName);
}
public const string LoadName = "GaussianRandom";
@@ -130,7 +131,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(LaplacianFourierSampler).Assembly.FullName);
}
public const string LoaderSignature = "RandLaplacianFourierExec";
diff --git a/src/Microsoft.ML.Transforms/GcnTransform.cs b/src/Microsoft.ML.Transforms/GcnTransform.cs
index 67a6bbd951..d8d1618a41 100644
--- a/src/Microsoft.ML.Transforms/GcnTransform.cs
+++ b/src/Microsoft.ML.Transforms/GcnTransform.cs
@@ -236,7 +236,8 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
- loaderSignatureAlt: LoaderSignatureOld);
+ loaderSignatureAlt: LoaderSignatureOld,
+ loaderAssemblyName: typeof(LpNormNormalizerTransform).Assembly.FullName);
}
private const string RegistrationName = "LpNormNormalizer";
diff --git a/src/Microsoft.ML.Transforms/GroupTransform.cs b/src/Microsoft.ML.Transforms/GroupTransform.cs
index c3a0bf8736..a40a9ea827 100644
--- a/src/Microsoft.ML.Transforms/GroupTransform.cs
+++ b/src/Microsoft.ML.Transforms/GroupTransform.cs
@@ -65,7 +65,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(GroupTransform).Assembly.FullName);
}
// REVIEW: maybe we want to have an option to keep all non-group scalar columns, as opposed to
diff --git a/src/Microsoft.ML.Transforms/HashJoinTransform.cs b/src/Microsoft.ML.Transforms/HashJoinTransform.cs
index d8e58c84ed..bd3d007fa8 100644
--- a/src/Microsoft.ML.Transforms/HashJoinTransform.cs
+++ b/src/Microsoft.ML.Transforms/HashJoinTransform.cs
@@ -168,7 +168,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010005, // Hash fix
verReadableCur: 0x00010005,
verWeCanReadBack: 0x00010005,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(HashJoinTransform).Assembly.FullName);
}
private readonly ColumnInfoEx[] _exes;
diff --git a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs
index f1503c336a..4f947415af 100644
--- a/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs
+++ b/src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs
@@ -59,7 +59,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00000001, // Initial
verReadableCur: 0x00000001,
verWeCanReadBack: 0x00000001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(KeyToBinaryVectorTransform).Assembly.FullName);
}
private const string RegistrationName = "KeyToBinary";
diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs
index 0a8d247766..41412e6e26 100644
--- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs
+++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs
@@ -59,7 +59,8 @@ private static VersionInfo GetVersionInfo()
loaderSignature: LoaderSignature,
// This is an older name and can be removed once we don't care about old code
// being able to load this.
- loaderSignatureAlt: "MissingFeatureFunction");
+ loaderSignatureAlt: "MissingFeatureFunction",
+ loaderAssemblyName: typeof(MissingValueIndicatorTransform).Assembly.FullName);
}
private const string RegistrationName = "MissingIndicator";
diff --git a/src/Microsoft.ML.Transforms/NADropTransform.cs b/src/Microsoft.ML.Transforms/NADropTransform.cs
index ea47fb9d1a..6d67fa5efe 100644
--- a/src/Microsoft.ML.Transforms/NADropTransform.cs
+++ b/src/Microsoft.ML.Transforms/NADropTransform.cs
@@ -59,7 +59,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(NADropTransform).Assembly.FullName);
}
private const string RegistrationName = "DropNAs";
diff --git a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs
index 39b3d650d2..044751f459 100644
--- a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs
+++ b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs
@@ -59,7 +59,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(NAIndicatorTransform).Assembly.FullName);
}
internal const string Summary = "Create a boolean output column with the same number of slots as the input column, where the output value"
diff --git a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs
index 1f3b4ee220..16ebdcae6a 100644
--- a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs
+++ b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs
@@ -138,7 +138,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x0010002, // Added imputation methods.
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoadName);
+ loaderSignature: LoadName,
+ loaderAssemblyName: typeof(NAReplaceTransform).Assembly.FullName);
}
internal const string Summary = "Create an output column of the same type and size of the input column, where missing values "
diff --git a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs
index 5117496194..b78f008215 100644
--- a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs
+++ b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs
@@ -226,7 +226,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Save the input schema, for metadata
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010002,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(OptionalColumnTransform).Assembly.FullName);
}
private readonly Bindings _bindings;
diff --git a/src/Microsoft.ML.Transforms/ProduceIdTransform.cs b/src/Microsoft.ML.Transforms/ProduceIdTransform.cs
index 0934fd0086..b87476cefa 100644
--- a/src/Microsoft.ML.Transforms/ProduceIdTransform.cs
+++ b/src/Microsoft.ML.Transforms/ProduceIdTransform.cs
@@ -84,7 +84,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(ProduceIdTransform).Assembly.FullName);
}
private readonly Bindings _bindings;
diff --git a/src/Microsoft.ML.Transforms/RffTransform.cs b/src/Microsoft.ML.Transforms/RffTransform.cs
index 7cb51a6f65..866683bb99 100644
--- a/src/Microsoft.ML.Transforms/RffTransform.cs
+++ b/src/Microsoft.ML.Transforms/RffTransform.cs
@@ -220,7 +220,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(RffTransform).Assembly.FullName);
}
// These are parallel to Infos.
diff --git a/src/Microsoft.ML.Transforms/TermLookupTransform.cs b/src/Microsoft.ML.Transforms/TermLookupTransform.cs
index 98dc7a0933..9cadc31d9f 100644
--- a/src/Microsoft.ML.Transforms/TermLookupTransform.cs
+++ b/src/Microsoft.ML.Transforms/TermLookupTransform.cs
@@ -277,7 +277,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Dropped sizeof(Float).
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010002,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(TermLookupTransform).Assembly.FullName);
}
// This is the byte array containing the binary .idv file contents for the lookup data.
diff --git a/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs b/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs
index 47c6762d1b..da10ce2b65 100644
--- a/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/CharTokenizeTransform.cs
@@ -75,7 +75,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Updated to use UnitSeparator character instead of using for vector inputs.
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(CharTokenizeTransform).Assembly.FullName);
}
// Controls whether to mark the beginning/end of each row/slot with TextStartMarker/TextEndMarker.
diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs
index 380399d96a..f15546d2cc 100644
--- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs
@@ -291,7 +291,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(LdaTransform).Assembly.FullName);
}
private readonly ColInfoEx[] _exes;
diff --git a/src/Microsoft.ML.Transforms/Text/NgramHashTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramHashTransform.cs
index 8ab64f9da4..72eb1d650a 100644
--- a/src/Microsoft.ML.Transforms/Text/NgramHashTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/NgramHashTransform.cs
@@ -317,7 +317,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Invert hash key values, hash fix
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010002,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(NgramHashTransform).Assembly.FullName);
}
private readonly Bindings _bindings;
diff --git a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs
index a696f2b092..3f24f290f9 100644
--- a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs
@@ -200,7 +200,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010002, // Add support for TF-IDF
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(NgramTransform).Assembly.FullName);
}
private readonly VectorType[] _types;
diff --git a/src/Microsoft.ML.Transforms/Text/StopWordsRemoverTransform.cs b/src/Microsoft.ML.Transforms/Text/StopWordsRemoverTransform.cs
index d278326b33..42fac1e2bb 100644
--- a/src/Microsoft.ML.Transforms/Text/StopWordsRemoverTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/StopWordsRemoverTransform.cs
@@ -235,7 +235,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(StopWordsRemoverTransform).Assembly.FullName);
}
private readonly bool?[] _resourcesExist;
@@ -602,10 +603,11 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(CustomStopWordsRemoverTransform).Assembly.FullName);
}
- public const string StopwrodsManagerLoaderSignature = "CustomStopWordsManager";
+ public const string StopwordsManagerLoaderSignature = "CustomStopWordsManager";
private static VersionInfo GetStopwrodsManagerVersionInfo()
{
return new VersionInfo(
@@ -613,7 +615,8 @@ private static VersionInfo GetStopwrodsManagerVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: StopwrodsManagerLoaderSignature);
+ loaderSignature: StopwordsManagerLoaderSignature,
+ loaderAssemblyName: typeof(CustomStopWordsRemoverTransform).Assembly.FullName);
}
private static readonly ColumnType _outputType = new VectorType(TextType.Instance);
diff --git a/src/Microsoft.ML.Transforms/Text/TextNormalizerTransform.cs b/src/Microsoft.ML.Transforms/Text/TextNormalizerTransform.cs
index c96b1511d1..87a33fe2bf 100644
--- a/src/Microsoft.ML.Transforms/Text/TextNormalizerTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/TextNormalizerTransform.cs
@@ -86,7 +86,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(TextNormalizerTransform).Assembly.FullName);
}
private const string RegistrationName = "TextNormalizer";
diff --git a/src/Microsoft.ML.Transforms/Text/TextTransform.cs b/src/Microsoft.ML.Transforms/Text/TextTransform.cs
index cc801dd311..77e7739ad7 100644
--- a/src/Microsoft.ML.Transforms/Text/TextTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/TextTransform.cs
@@ -639,7 +639,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(Transformer).Assembly.FullName);
}
}
diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs
index 5a5b382d75..b9ef52a414 100644
--- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs
@@ -79,7 +79,8 @@ public static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, //Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(WordEmbeddingsTransform).Assembly.FullName);
}
private readonly PretrainedModelKind? _modelKind;
diff --git a/src/Microsoft.ML.Transforms/Text/WordTokenizeTransform.cs b/src/Microsoft.ML.Transforms/Text/WordTokenizeTransform.cs
index 35bfbd925c..391cfa0327 100644
--- a/src/Microsoft.ML.Transforms/Text/WordTokenizeTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/WordTokenizeTransform.cs
@@ -141,7 +141,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(DelimitedTokenizeTransform).Assembly.FullName);
}
public override bool CanSavePfa => true;
diff --git a/src/Microsoft.ML.Transforms/UngroupTransform.cs b/src/Microsoft.ML.Transforms/UngroupTransform.cs
index d77cc48bbd..cc49ffe8f3 100644
--- a/src/Microsoft.ML.Transforms/UngroupTransform.cs
+++ b/src/Microsoft.ML.Transforms/UngroupTransform.cs
@@ -58,7 +58,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(UngroupTransform).Assembly.FullName);
}
///
diff --git a/src/Microsoft.ML.Transforms/WhiteningTransform.cs b/src/Microsoft.ML.Transforms/WhiteningTransform.cs
index 5d1b3188b7..21c0870633 100644
--- a/src/Microsoft.ML.Transforms/WhiteningTransform.cs
+++ b/src/Microsoft.ML.Transforms/WhiteningTransform.cs
@@ -211,7 +211,8 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
- loaderSignatureAlt: LoaderSignatureOld);
+ loaderSignatureAlt: LoaderSignatureOld,
+ loaderAssemblyName: typeof(WhiteningTransform).Assembly.FullName);
}
private readonly ColInfoEx[] _exes;
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
index b3d7068041..11407a3ec4 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
@@ -6,6 +6,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.TestFramework;
+using Microsoft.ML.Transforms;
using System;
using System.Collections.Generic;
using System.Linq;
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs
index 53b77440dc..3798246fa8 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs
@@ -13,8 +13,10 @@ public sealed class TestEarlyStoppingCriteria
{
private IEarlyStoppingCriterion CreateEarlyStoppingCriterion(string name, string args, bool lowerIsBetter)
{
+ var env = new ConsoleEnvironment()
+ .AddStandardComponents();
var sub = new SubComponent(name, args);
- return sub.CreateInstance(new ConsoleEnvironment(), lowerIsBetter);
+ return sub.CreateInstance(env, lowerIsBetter);
}
[Fact]
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
index d674d99c61..d94192f12e 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
@@ -15,11 +15,17 @@
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.EntryPoints.JsonUtils;
using Microsoft.ML.Runtime.FastTree;
+using Microsoft.ML.Runtime.ImageAnalytics;
using Microsoft.ML.Runtime.Internal.Calibration;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Learners;
+using Microsoft.ML.Runtime.LightGBM;
+using Microsoft.ML.Runtime.Model.Onnx;
using Microsoft.ML.Runtime.PCA;
+using Microsoft.ML.Runtime.PipelineInference;
+using Microsoft.ML.Runtime.SymSgd;
using Microsoft.ML.Runtime.TextAnalytics;
+using Microsoft.ML.Transforms;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using Xunit;
@@ -237,37 +243,19 @@ private string GetBuildPrefix()
[Fact(Skip = "Execute this test if you want to regenerate ep-list and _manifest.json")]
public void RegenerateEntryPointCatalog()
{
+ var (epListContents, jObj) = BuildManifests();
+
var buildPrefix = GetBuildPrefix();
var epListFile = buildPrefix + "_ep-list.tsv";
- var manifestFile = buildPrefix + "_manifest.json";
var entryPointsSubDir = Path.Combine("..", "Common", "EntryPoints");
var catalog = ModuleCatalog.CreateInstance(Env);
var epListPath = GetBaselinePath(entryPointsSubDir, epListFile);
DeleteOutputPath(epListPath);
- var regex = new Regex(@"\r\n?|\n", RegexOptions.Compiled);
- File.WriteAllLines(epListPath, catalog.AllEntryPoints()
- .Select(x => string.Join("\t",
- x.Name,
- regex.Replace(x.Description, ""),
- x.Method.DeclaringType,
- x.Method.Name,
- x.InputType,
- x.OutputType)
- .Replace(Environment.NewLine, ""))
- .OrderBy(x => x));
-
+ File.WriteAllLines(epListPath, epListContents);
- var jObj = JsonManifestUtils.BuildAllManifests(Env, catalog);
-
- //clean up the description from the new line characters
- if (jObj[FieldNames.TopEntryPoints] != null && jObj[FieldNames.TopEntryPoints] is JArray)
- {
- foreach (JToken entry in jObj[FieldNames.TopEntryPoints].Children())
- if (entry[FieldNames.Desc] != null)
- entry[FieldNames.Desc] = regex.Replace(entry[FieldNames.Desc].ToString(), "");
- }
+ var manifestFile = buildPrefix + "_manifest.json";
var manifestPath = GetBaselinePath(entryPointsSubDir, manifestFile);
DeleteOutputPath(manifestPath);
@@ -280,20 +268,49 @@ public void RegenerateEntryPointCatalog()
}
}
-
[Fact]
public void EntryPointCatalog()
{
+ var (epListContents, jObj) = BuildManifests();
+
var buildPrefix = GetBuildPrefix();
var epListFile = buildPrefix + "_ep-list.tsv";
- var manifestFile = buildPrefix + "_manifest.json";
var entryPointsSubDir = Path.Combine("..", "Common", "EntryPoints");
var catalog = ModuleCatalog.CreateInstance(Env);
var path = DeleteOutputPath(entryPointsSubDir, epListFile);
+ File.WriteAllLines(path, epListContents);
+
+ CheckEquality(entryPointsSubDir, epListFile);
+
+ var manifestFile = buildPrefix + "_manifest.json";
+ var jPath = DeleteOutputPath(entryPointsSubDir, manifestFile);
+ using (var file = File.OpenWrite(jPath))
+ using (var writer = new StreamWriter(file))
+ using (var jw = new JsonTextWriter(writer))
+ {
+ jw.Formatting = Formatting.Indented;
+ jObj.WriteTo(jw);
+ }
+
+ CheckEquality(entryPointsSubDir, manifestFile);
+ Done();
+ }
+
+ private (IEnumerable epListContents, JObject manifest) BuildManifests()
+ {
+ Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryPredictor).Assembly);
+ Env.ComponentCatalog.RegisterAssembly(typeof(TensorFlowTransform).Assembly);
+ Env.ComponentCatalog.RegisterAssembly(typeof(ImageLoaderTransform).Assembly);
+ Env.ComponentCatalog.RegisterAssembly(typeof(SymSgdClassificationTrainer).Assembly);
+ Env.ComponentCatalog.RegisterAssembly(typeof(AutoInference).Assembly);
+ Env.ComponentCatalog.RegisterAssembly(typeof(SaveOnnxCommand).Assembly);
+
+ var catalog = ModuleCatalog.CreateInstance(Env);
+
var regex = new Regex(@"\r\n?|\n", RegexOptions.Compiled);
- File.WriteAllLines(path, catalog.AllEntryPoints()
+ var epListContents = catalog.AllEntryPoints()
.Select(x => string.Join("\t",
x.Name,
regex.Replace(x.Description, ""),
@@ -302,39 +319,27 @@ public void EntryPointCatalog()
x.InputType,
x.OutputType)
.Replace(Environment.NewLine, ""))
- .OrderBy(x => x));
+ .OrderBy(x => x);
- CheckEquality(entryPointsSubDir, epListFile);
-
- var jObj = JsonManifestUtils.BuildAllManifests(Env, catalog);
+ var manifest = JsonManifestUtils.BuildAllManifests(Env, catalog);
//clean up the description from the new line characters
- if (jObj[FieldNames.TopEntryPoints] != null && jObj[FieldNames.TopEntryPoints] is JArray)
+ if (manifest[FieldNames.TopEntryPoints] != null && manifest[FieldNames.TopEntryPoints] is JArray)
{
- foreach (JToken entry in jObj[FieldNames.TopEntryPoints].Children())
+ foreach (JToken entry in manifest[FieldNames.TopEntryPoints].Children())
if (entry[FieldNames.Desc] != null)
entry[FieldNames.Desc] = regex.Replace(entry[FieldNames.Desc].ToString(), "");
}
- var jPath = DeleteOutputPath(entryPointsSubDir, manifestFile);
- using (var file = File.OpenWrite(jPath))
- using (var writer = new StreamWriter(file))
- using (var jw = new JsonTextWriter(writer))
- {
- jw.Formatting = Formatting.Indented;
- jObj.WriteTo(jw);
- }
-
- CheckEquality(entryPointsSubDir, manifestFile);
- Done();
+ return (epListContents, manifest);
}
[Fact]
public void EntryPointInputBuilderOptionals()
{
- var catelog = ModuleCatalog.CreateInstance(Env);
+ var catalog = ModuleCatalog.CreateInstance(Env);
- InputBuilder ib1 = new InputBuilder(Env, typeof(LogisticRegression.Arguments), catelog);
+ InputBuilder ib1 = new InputBuilder(Env, typeof(LogisticRegression.Arguments), catalog);
// Ensure that InputBuilder unwraps the Optional correctly.
var weightType = ib1.GetFieldTypeOrNull("WeightColumn");
Assert.True(weightType.Equals(typeof(string)));
@@ -1794,12 +1799,14 @@ public void EntryPointEvaluateRanking()
[Fact]
public void EntryPointLightGbmBinary()
{
+ Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryPredictor).Assembly);
TestEntryPointRoutine("breast-cancer.txt", "Trainers.LightGbmBinaryClassifier");
}
[Fact]
public void EntryPointLightGbmMultiClass()
{
+ Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryPredictor).Assembly);
TestEntryPointRoutine(GetDataPath(@"iris.txt"), "Trainers.LightGbmClassifier");
}
@@ -3728,6 +3735,8 @@ public void EntryPointWordEmbeddings()
[Fact]
public void EntryPointTensorFlowTransform()
{
+ Env.ComponentCatalog.RegisterAssembly(typeof(TensorFlowTransform).Assembly);
+
TestEntryPointPipelineRoutine(GetDataPath("Train-Tiny-28x28.txt"), "col=Label:R4:0 col=Placeholder:R4:1-784",
new[] { "Transforms.TensorFlowScorer" },
new[]
diff --git a/test/Microsoft.ML.InferenceTesting/GenerateSweepCandidatesCommand.cs b/test/Microsoft.ML.InferenceTesting/GenerateSweepCandidatesCommand.cs
index b964d4dae6..534ababbc7 100644
--- a/test/Microsoft.ML.InferenceTesting/GenerateSweepCandidatesCommand.cs
+++ b/test/Microsoft.ML.InferenceTesting/GenerateSweepCandidatesCommand.cs
@@ -85,7 +85,7 @@ public GenerateSweepCandidatesCommand(IHostEnvironment env, Arguments args)
if (!string.IsNullOrWhiteSpace(args.Sweeper))
{
- var info = ComponentCatalog.GetLoadableClassInfo(args.Sweeper);
+ var info = env.ComponentCatalog.GetLoadableClassInfo(args.Sweeper);
_host.CheckUserArg(info?.SignatureTypes[0] == typeof(SignatureSweeper), nameof(args.Sweeper),
"Please specify a valid sweeper.");
_sweeper = args.Sweeper;
@@ -95,7 +95,7 @@ public GenerateSweepCandidatesCommand(IHostEnvironment env, Arguments args)
if (!string.IsNullOrWhiteSpace(args.Mode))
{
- var info = ComponentCatalog.GetLoadableClassInfo(args.Mode);
+ var info = env.ComponentCatalog.GetLoadableClassInfo(args.Mode);
_host.CheckUserArg(info?.Type == typeof(TrainCommand) ||
info?.Type == typeof(TrainTestCommand) ||
info?.Type == typeof(CrossValidationCommand), nameof(args.Mode), "Invalid mode.");
diff --git a/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs b/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs
index 2369136f95..6d95829454 100644
--- a/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs
+++ b/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs
@@ -113,9 +113,9 @@ private class SimpleArg
/// ToString is overrided by CmdParser.GetSettings which is of primary for this test
///
///
- public string ToString(IExceptionContext ectx)
+ public string ToString(IHostEnvironment env)
{
- return CmdParser.GetSettings(ectx, this, new SimpleArg(), SettingsFlags.None);
+ return CmdParser.GetSettings(env, this, new SimpleArg(), SettingsFlags.None);
}
public override bool Equals(object obj)
diff --git a/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs b/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs
index 1b991300ea..9d0af99420 100644
--- a/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs
+++ b/test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs
@@ -6,6 +6,7 @@
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.EntryPoints.JsonUtils;
using Microsoft.ML.Runtime.PipelineInference;
+using Microsoft.ML.TestFramework;
using Newtonsoft.Json.Linq;
using System.Collections.Generic;
using System.Linq;
@@ -26,7 +27,8 @@ public TestAutoInference(ITestOutputHelper helper)
[TestCategory("EntryPoints")]
public void TestLearn()
{
- using (var env = new ConsoleEnvironment())
+ using (var env = new LocalEnvironment()
+ .AddStandardComponents()) // AutoInference.InferPipelines uses ComponentCatalog to read text data
{
string pathData = GetDataPath("adult.train");
string pathDataTest = GetDataPath("adult.test");
@@ -72,7 +74,8 @@ public void TestLearn()
[Fact(Skip = "Need CoreTLC specific baseline update")]
public void TestTextDatasetLearn()
{
- using (var env = new ConsoleEnvironment())
+ using (var env = new LocalEnvironment()
+ .AddStandardComponents()) // AutoInference uses ComponentCatalog to find all learners
{
string pathData = GetDataPath(@"../UnitTest/tweets_labeled_10k_test_validation.tsv");
int batchSize = 5;
@@ -96,7 +99,8 @@ public void TestTextDatasetLearn()
[Fact]
public void TestPipelineNodeCloning()
{
- using (var env = new ConsoleEnvironment())
+ using (var env = new LocalEnvironment()
+ .AddStandardComponents()) // RecipeInference.AllowedLearners uses ComponentCatalog to find all learners
{
var lr1 = RecipeInference
.AllowedLearners(env, MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer)
@@ -138,15 +142,16 @@ public void TestHyperparameterFreezing()
int batchSize = 1;
int numIterations = 10;
int numTransformLevels = 3;
- using (var env = new ConsoleEnvironment())
+ using (var env = new LocalEnvironment()
+ .AddStandardComponents()) // AutoInference uses ComponentCatalog to find all learners
{
SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc);
// Using the simple, uniform random sampling (with replacement) brain
- PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(Env);
+ PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(env);
// Run initial experiments
- var amls = AutoInference.InferPipelines(Env, autoMlBrain, pathData, "", out var _, numTransformLevels, batchSize,
+ var amls = AutoInference.InferPipelines(env, autoMlBrain, pathData, "", out var _, numTransformLevels, batchSize,
metric, out var bestPipeline, numOfSampleRows, new IterationTerminator(numIterations),
MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer);
@@ -186,15 +191,16 @@ public void TestRegressionPipelineWithMinimizingMetric()
int batchSize = 5;
int numIterations = 10;
int numTransformLevels = 1;
- using (var env = new ConsoleEnvironment())
+ using (var env = new LocalEnvironment()
+ .AddStandardComponents()) // AutoInference uses ComponentCatalog to find all learners
{
SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.AccuracyMicro);
// Using the simple, uniform random sampling (with replacement) brain
- PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(Env);
+ PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(env);
// Run initial experiments
- var amls = AutoInference.InferPipelines(Env, autoMlBrain, pathData, "", out var _, numTransformLevels, batchSize,
+ var amls = AutoInference.InferPipelines(env, autoMlBrain, pathData, "", out var _, numTransformLevels, batchSize,
metric, out var bestPipeline, numOfSampleRows, new IterationTerminator(numIterations),
MacroUtils.TrainerKinds.SignatureRegressorTrainer);
@@ -220,15 +226,16 @@ public void TestLearnerConstrainingByName()
int numIterations = 1;
int numTransformLevels = 2;
var retainedLearnerNames = new[] { $"LogisticRegressionBinaryClassifier", $"FastTreeBinaryClassifier" };
- using (var env = new ConsoleEnvironment())
+ using (var env = new LocalEnvironment()
+ .AddStandardComponents()) // AutoInference uses ComponentCatalog to find all learners
{
SupportedMetric metric = PipelineSweeperSupportedMetrics.GetSupportedMetric(PipelineSweeperSupportedMetrics.Metrics.Auc);
// Using the simple, uniform random sampling (with replacement) brain.
- PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(Env);
+ PipelineOptimizerBase autoMlBrain = new UniformRandomEngine(env);
// Run initial experiment.
- var amls = AutoInference.InferPipelines(Env, autoMlBrain, pathData, "", out var _,
+ var amls = AutoInference.InferPipelines(env, autoMlBrain, pathData, "", out var _,
numTransformLevels, batchSize, metric, out var _, numOfSampleRows,
new IterationTerminator(numIterations), MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer);
diff --git a/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs b/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs
index 5f65ca4c28..cd86e83414 100644
--- a/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs
+++ b/test/Microsoft.ML.Predictor.Tests/TestPipelineSweeper.cs
@@ -22,6 +22,12 @@ public TestPipelineSweeper(ITestOutputHelper helper)
{
}
+ protected override void InitializeCore()
+ {
+ base.InitializeCore();
+ Env.ComponentCatalog.RegisterAssembly(typeof(AutoInference).Assembly);
+ }
+
[Fact]
public void PipelineSweeperBasic()
{
diff --git a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs
index 4de917921b..fdf2f50168 100644
--- a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs
+++ b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs
@@ -7,7 +7,6 @@
using System;
using System.Collections.Generic;
using System.IO;
-using Microsoft.ML.Runtime.CommandLine;
namespace Microsoft.ML.Runtime.RunTests
{
@@ -15,6 +14,8 @@ namespace Microsoft.ML.Runtime.RunTests
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.FastTree;
using Microsoft.ML.Runtime.FastTree.Internal;
+ using Microsoft.ML.Runtime.LightGBM;
+ using Microsoft.ML.Runtime.SymSgd;
using Microsoft.ML.TestFramework;
using System.Linq;
using System.Runtime.InteropServices;
@@ -27,6 +28,20 @@ namespace Microsoft.ML.Runtime.RunTests
///
public sealed partial class TestPredictors : BaseTestPredictors
{
+ protected override void InitializeCore()
+ {
+ base.InitializeCore();
+ InitializeEnvironment(Env);
+ }
+
+ protected override void InitializeEnvironment(IHostEnvironment environment)
+ {
+ base.InitializeEnvironment(environment);
+
+ environment.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryPredictor).Assembly);
+ environment.ComponentCatalog.RegisterAssembly(typeof(SymSgdClassificationTrainer).Assembly);
+ }
+
///
/// Get a list of datasets for binary classifier base test.
///
diff --git a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs
index f3c21feabf..e4ce5fac50 100644
--- a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs
+++ b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs
@@ -108,7 +108,8 @@ public void Init()
string logPath = Path.Combine(logDir, FullTestName + LogSuffix);
LogWriter = OpenWriter(logPath);
_passed = true;
- Env = new ConsoleEnvironment(42, outWriter: LogWriter, errWriter: LogWriter);
+ Env = new ConsoleEnvironment(42, outWriter: LogWriter, errWriter: LogWriter)
+ .AddStandardComponents();
InitializeCore();
}
diff --git a/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs b/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs
index ddd9dfb8da..1b9eed9fe5 100644
--- a/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs
+++ b/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs
@@ -295,7 +295,7 @@ protected void RunResultProcessorTest(string[] dataFiles, string outPath, string
if (extraArgs != null)
args.AddRange(extraArgs);
- ResultProcessor.Main(args.ToArray());
+ ResultProcessor.Main(Env, args.ToArray());
}
private static string GetNamePrefix(string testType, PredictorAndArgs predictor, TestDataset dataset, string extraTag = "")
diff --git a/test/Microsoft.ML.TestFramework/DataPipe/Parquet.cs b/test/Microsoft.ML.TestFramework/DataPipe/Parquet.cs
index 26f265bf9e..b3b68c5a8a 100644
--- a/test/Microsoft.ML.TestFramework/DataPipe/Parquet.cs
+++ b/test/Microsoft.ML.TestFramework/DataPipe/Parquet.cs
@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using Microsoft.ML.Runtime.Data;
using System;
using Xunit;
@@ -9,6 +10,11 @@ namespace Microsoft.ML.Runtime.RunTests
{
public sealed partial class TestParquet : TestDataPipeBase
{
+ protected override void InitializeCore()
+ {
+ base.InitializeCore();
+ Env.ComponentCatalog.RegisterAssembly(typeof(ParquetLoader).Assembly);
+ }
[Fact]
public void TestParquetPrimitiveDataTypes()
diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs
index 0ec0e35341..e1b4986f06 100644
--- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs
+++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs
@@ -311,7 +311,7 @@ protected IDataLoader TestCore(string pathData, bool keepHidden, string[] argsPi
protected IDataLoader CreatePipeDataLoader(IHostEnvironment env, string pathData, string[] argsPipe, out MultiFileSource files)
{
- VerifyArgParsing(argsPipe);
+ VerifyArgParsing(env, argsPipe);
// Default to breast-cancer.txt.
if (string.IsNullOrEmpty(pathData))
@@ -350,7 +350,7 @@ protected void TestApplyTransformsToData(IHostEnvironment env, IDataLoader pipe,
Failed();
}
- protected void VerifyArgParsing(string[] strs)
+ protected void VerifyArgParsing(IHostEnvironment env, string[] strs)
{
string str = CmdParser.CombineSettings(strs);
var args = new CompositeDataLoader.Arguments();
@@ -361,18 +361,18 @@ protected void VerifyArgParsing(string[] strs)
}
// For the loader and each transform, verify that custom unparsing is correct.
- VerifyCustArgs(args.Loader);
+ VerifyCustArgs(env, args.Loader);
foreach (var kvp in args.Transform)
- VerifyCustArgs(kvp.Value);
+ VerifyCustArgs(env, kvp.Value);
}
- protected void VerifyCustArgs(IComponentFactory factory)
+ protected void VerifyCustArgs(IHostEnvironment env, IComponentFactory factory)
where TRes : class
{
if (factory is ICommandLineComponentFactory commandLineFactory)
{
var str = commandLineFactory.GetSettingsString();
- var info = ComponentCatalog.GetLoadableClassInfo(commandLineFactory.Name, commandLineFactory.SignatureType);
+ var info = env.ComponentCatalog.GetLoadableClassInfo(commandLineFactory.Name, commandLineFactory.SignatureType);
Assert.NotNull(info);
var def = info.CreateArguments();
diff --git a/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs b/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs
new file mode 100644
index 0000000000..986e82a9c4
--- /dev/null
+++ b/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs
@@ -0,0 +1,31 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Ensemble;
+using Microsoft.ML.Runtime.FastTree;
+using Microsoft.ML.Runtime.KMeans;
+using Microsoft.ML.Runtime.Learners;
+using Microsoft.ML.Runtime.PCA;
+
+namespace Microsoft.ML.TestFramework
+{
+ public static class EnvironmentExtensions
+ {
+ public static TEnvironment AddStandardComponents(this TEnvironment env)
+ where TEnvironment : IHostEnvironment
+ {
+ env.ComponentCatalog.RegisterAssembly(typeof(TextLoader).Assembly); // ML.Data
+ env.ComponentCatalog.RegisterAssembly(typeof(LinearPredictor).Assembly); // ML.StandardLearners
+ env.ComponentCatalog.RegisterAssembly(typeof(CategoricalTransform).Assembly); // ML.Transforms
+ env.ComponentCatalog.RegisterAssembly(typeof(FastTreeBinaryPredictor).Assembly); // ML.FastTree
+ env.ComponentCatalog.RegisterAssembly(typeof(EnsemblePredictor).Assembly); // ML.Ensemble
+ env.ComponentCatalog.RegisterAssembly(typeof(KMeansPredictor).Assembly); // ML.KMeansClustering
+ env.ComponentCatalog.RegisterAssembly(typeof(PcaPredictor).Assembly); // ML.PCA
+ env.ComponentCatalog.RegisterAssembly(typeof(Experiment).Assembly); // ML.Legacy
+ return env;
+ }
+ }
+}
diff --git a/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj b/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj
index ad50df8042..b191c08a91 100644
--- a/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj
+++ b/test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj
@@ -7,11 +7,15 @@
+
+
+
+
diff --git a/test/Microsoft.ML.TestFramework/TestCommandBase.cs b/test/Microsoft.ML.TestFramework/TestCommandBase.cs
index a0c72a15ab..10b0e9552c 100644
--- a/test/Microsoft.ML.TestFramework/TestCommandBase.cs
+++ b/test/Microsoft.ML.TestFramework/TestCommandBase.cs
@@ -14,6 +14,7 @@
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.Tools;
+using Microsoft.ML.TestFramework;
using Xunit;
using Xunit.Abstractions;
@@ -270,6 +271,11 @@ public virtual OutputPath MetricsPath()
}
}
+ protected virtual void InitializeEnvironment(IHostEnvironment environment)
+ {
+ environment.AddStandardComponents();
+ }
+
///
/// Runs a command with some arguments. Note that the input
/// objects are used for comparison only.
@@ -283,6 +289,8 @@ protected bool TestCore(RunContextBase ctx, string cmdName, string args, params
using (var newWriter = OpenWriter(outputPath.Path))
using (var env = new ConsoleEnvironment(42, outWriter: newWriter, errWriter: newWriter))
{
+ InitializeEnvironment(env);
+
int res;
res = MainForTest(env, newWriter, string.Format("{0} {1}", cmdName, args), ctx.BaselineProgress);
if (res != 0)
diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs
index adc229e328..c5aabae99d 100644
--- a/test/Microsoft.ML.Tests/ImagesTests.cs
+++ b/test/Microsoft.ML.Tests/ImagesTests.cs
@@ -29,8 +29,21 @@ public void TestEstimatorChain()
{
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
- var invalidData = env.CreateLoader("Text{col=ImagePath:R4:0}", new MultiFileSource(dataFile));
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
+ {
+ Column = new[]
+ {
+ new TextLoader.Column("ImagePath", DataKind.TX, 0),
+ new TextLoader.Column("Name", DataKind.TX, 1),
+ }
+ }, new MultiFileSource(dataFile));
+ var invalidData = TextLoader.Create(env, new TextLoader.Arguments()
+ {
+ Column = new[]
+ {
+ new TextLoader.Column("ImagePath", DataKind.R4, 0),
+ }
+ }, new MultiFileSource(dataFile));
var pipe = new ImageLoaderEstimator(env, imageFolder, ("ImagePath", "ImageReal"))
.Append(new ImageResizerEstimator(env, "ImageReal", "ImageReal", 100, 100))
@@ -49,7 +62,14 @@ public void TestEstimatorSaveLoad()
{
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
+ {
+ Column = new[]
+ {
+ new TextLoader.Column("ImagePath", DataKind.TX, 0),
+ new TextLoader.Column("Name", DataKind.TX, 1),
+ }
+ }, new MultiFileSource(dataFile));
var pipe = new ImageLoaderEstimator(env, imageFolder, ("ImagePath", "ImageReal"))
.Append(new ImageResizerEstimator(env, "ImageReal", "ImageReal", 100, 100))
@@ -82,7 +102,14 @@ public void TestSaveImages()
{
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
+ {
+ Column = new[]
+ {
+ new TextLoader.Column("ImagePath", DataKind.TX, 0),
+ new TextLoader.Column("Name", DataKind.TX, 1),
+ }
+ }, new MultiFileSource(dataFile));
var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
{
Column = new ImageLoaderTransform.Column[1]
@@ -129,7 +156,14 @@ public void TestGreyscaleTransformImages()
var imageWidth = 100;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
+ {
+ Column = new[]
+ {
+ new TextLoader.Column("ImagePath", DataKind.TX, 0),
+ new TextLoader.Column("Name", DataKind.TX, 1),
+ }
+ }, new MultiFileSource(dataFile));
var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
{
Column = new ImageLoaderTransform.Column[1]
@@ -192,7 +226,14 @@ public void TestBackAndForthConversionWithAlphaInterleave()
var imageWidth = 130;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
+ {
+ Column = new[]
+ {
+ new TextLoader.Column("ImagePath", DataKind.TX, 0),
+ new TextLoader.Column("Name", DataKind.TX, 1),
+ }
+ }, new MultiFileSource(dataFile));
var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
{
Column = new ImageLoaderTransform.Column[1]
@@ -275,7 +316,14 @@ public void TestBackAndForthConversionWithoutAlphaInterleave()
var imageWidth = 130;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
+ {
+ Column = new[]
+ {
+ new TextLoader.Column("ImagePath", DataKind.TX, 0),
+ new TextLoader.Column("Name", DataKind.TX, 1),
+ }
+ }, new MultiFileSource(dataFile));
var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
{
Column = new ImageLoaderTransform.Column[1]
@@ -358,7 +406,14 @@ public void TestBackAndForthConversionWithAlphaNoInterleave()
var imageWidth = 130;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
+ {
+ Column = new[]
+ {
+ new TextLoader.Column("ImagePath", DataKind.TX, 0),
+ new TextLoader.Column("Name", DataKind.TX, 1),
+ }
+ }, new MultiFileSource(dataFile));
var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
{
Column = new ImageLoaderTransform.Column[1]
@@ -441,7 +496,14 @@ public void TestBackAndForthConversionWithoutAlphaNoInterleave()
var imageWidth = 130;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
+ {
+ Column = new[]
+ {
+ new TextLoader.Column("ImagePath", DataKind.TX, 0),
+ new TextLoader.Column("Name", DataKind.TX, 1),
+ }
+ }, new MultiFileSource(dataFile));
var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
{
Column = new ImageLoaderTransform.Column[1]
@@ -524,7 +586,14 @@ public void TestBackAndForthConversionWithAlphaInterleaveNoOffset()
var imageWidth = 130;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
+ {
+ Column = new[]
+ {
+ new TextLoader.Column("ImagePath", DataKind.TX, 0),
+ new TextLoader.Column("Name", DataKind.TX, 1),
+ }
+ }, new MultiFileSource(dataFile));
var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
{
Column = new ImageLoaderTransform.Column[1]
@@ -603,7 +672,14 @@ public void TestBackAndForthConversionWithoutAlphaInterleaveNoOffset()
var imageWidth = 130;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
+ {
+ Column = new[]
+ {
+ new TextLoader.Column("ImagePath", DataKind.TX, 0),
+ new TextLoader.Column("Name", DataKind.TX, 1),
+ }
+ }, new MultiFileSource(dataFile));
var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
{
Column = new ImageLoaderTransform.Column[1]
@@ -682,7 +758,14 @@ public void TestBackAndForthConversionWithAlphaNoInterleaveNoOffset()
var imageWidth = 130;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
+ {
+ Column = new[]
+ {
+ new TextLoader.Column("ImagePath", DataKind.TX, 0),
+ new TextLoader.Column("Name", DataKind.TX, 1),
+ }
+ }, new MultiFileSource(dataFile));
var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
{
Column = new ImageLoaderTransform.Column[1]
@@ -761,7 +844,14 @@ public void TestBackAndForthConversionWithoutAlphaNoInterleaveNoOffset()
var imageWidth = 130;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
+ {
+ Column = new[]
+ {
+ new TextLoader.Column("ImagePath", DataKind.TX, 0),
+ new TextLoader.Column("Name", DataKind.TX, 1),
+ }
+ }, new MultiFileSource(dataFile));
var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
{
Column = new ImageLoaderTransform.Column[1]
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs
index 0d20c937c3..68857999e8 100644
--- a/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs
@@ -6,6 +6,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Runtime.RunTests;
+using Microsoft.ML.TestFramework;
using System.Linq;
using Xunit;
@@ -25,7 +26,8 @@ public partial class ApiScenariosTests
[Fact]
void DecomposableTrainAndPredict()
{
- using (var env = new LocalEnvironment())
+ using (var env = new LocalEnvironment()
+ .AddStandardComponents()) // ScoreUtils.GetScorer requires scorers to be registered in the ComponentCatalog
{
var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.irisData.trainFilename)));
var term = TermTransform.Create(env, loader, "Label");
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs
index 6e937ce238..68d7994842 100644
--- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs
@@ -6,6 +6,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Runtime.RunTests;
+using Microsoft.ML.TestFramework;
using System.Linq;
using Xunit;
@@ -25,7 +26,8 @@ public partial class ApiScenariosTests
void New_DecomposableTrainAndPredict()
{
var dataPath = GetDataPath(TestDatasets.irisData.trainFilename);
- using (var env = new LocalEnvironment())
+ using (var env = new LocalEnvironment()
+ .AddStandardComponents()) // ScoreUtils.GetScorer requires scorers to be registered in the ComponentCatalog
{
var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
var term = TermTransform.Create(env, loader, "Label");
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs
index c5a6a40703..aa425e3f4e 100644
--- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs
@@ -52,7 +52,8 @@ private static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
- loaderSignature: LoaderSignature);
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(LoaderWrapper).Assembly.FullName);
}
public LoaderWrapper(IHostEnvironment env, ModelLoadContext ctx)
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs
index f77dac19ba..15e0824c7d 100644
--- a/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs
@@ -2,6 +2,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Runtime.RunTests;
+using Microsoft.ML.TestFramework;
using System;
using System.Linq;
using Xunit;
@@ -19,7 +20,8 @@ public partial class ApiScenariosTests
[Fact]
void Extensibility()
{
- using (var env = new LocalEnvironment())
+ using (var env = new LocalEnvironment()
+ .AddStandardComponents()) // ScoreUtils.GetScorer requires scorers to be registered in the ComponentCatalog
{
var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.irisData.trainFilename)));
Action action = (i, j) =>
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
index 71f5f95f33..609ccf9c1c 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
@@ -420,7 +420,14 @@ public void TensorFlowTransformCifar()
var imageWidth = 32;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
+ {
+ Column = new[]
+ {
+ new TextLoader.Column("ImagePath", DataKind.TX, 0),
+ new TextLoader.Column("Name", DataKind.TX, 1),
+ }
+ }, new MultiFileSource(dataFile));
var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
{
Column = new ImageLoaderTransform.Column[1]
@@ -474,8 +481,14 @@ public void TensorFlowTransformCifarInvalidShape()
var imageWidth = 28;
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
- var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
-
+ var data = TextLoader.Create(env, new TextLoader.Arguments()
+ {
+ Column = new[]
+ {
+ new TextLoader.Column("ImagePath", DataKind.TX, 0),
+ new TextLoader.Column("Name", DataKind.TX, 1),
+ }
+ }, new MultiFileSource(dataFile));
var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
{
Column = new ImageLoaderTransform.Column[1]
diff --git a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs
index d102ea758c..6ea74d8dbb 100644
--- a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs
+++ b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs
@@ -134,7 +134,7 @@ void TestCommandLine()
{
using (var env = new ConsoleEnvironment())
{
- Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=a:R4:0-3 col=b:R4:0-3} xf=TFTransform{inputs=a inputs=b outputs=c model={model_matmul/frozen_saved_model.pb}} in=f:\2.txt" }), (int)0);
+ Assert.Equal(0, Maml.Main(new[] { @"showschema loader=Text{col=a:R4:0-3 col=b:R4:0-3} xf=TFTransform{inputs=a inputs=b outputs=c model={model_matmul/frozen_saved_model.pb}} in=f:\2.txt" }));
}
}
diff --git a/test/Microsoft.ML.Tests/TermEstimatorTests.cs b/test/Microsoft.ML.Tests/TermEstimatorTests.cs
index c69c91d23c..93c827ff30 100644
--- a/test/Microsoft.ML.Tests/TermEstimatorTests.cs
+++ b/test/Microsoft.ML.Tests/TermEstimatorTests.cs
@@ -154,7 +154,7 @@ void TestCommandLine()
{
using (var env = new ConsoleEnvironment())
{
- Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} in=f:\2.txt" }), (int)0);
+ Assert.Equal(0, Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} in=f:\2.txt" }));
}
}
diff --git a/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs
index b995f7c984..e4746ccffd 100644
--- a/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs
@@ -298,7 +298,7 @@ private void ValidateBagMetadata(IDataView result)
[Fact]
public void TestCommandLine()
{
- Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Cat{col=B:A} in=f:\2.txt" }), (int)0);
+ Assert.Equal(0, Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Cat{col=B:A} in=f:\2.txt" }));
}
[Fact]
diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs
index 56036bbc02..c5a9620c2b 100644
--- a/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs
+++ b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs
@@ -151,7 +151,7 @@ private void ValidateMetadata(IDataView result)
[Fact]
public void TestCommandLine()
{
- Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToBinary{col=C:B} in=f:\2.txt" }), (int)0);
+ Assert.Equal(0, Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToBinary{col=C:B} in=f:\2.txt" }));
}
[Fact]
diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs
index faf8b47e79..a460d9482b 100644
--- a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs
@@ -216,7 +216,7 @@ private void ValidateMetadata(IDataView result)
[Fact]
public void TestCommandLine()
{
- Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToVector{col=C:B col={name=D source=B bag+}} in=f:\2.txt" }), (int)0);
+ Assert.Equal(0, Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} xf=KeyToVector{col=C:B col={name=D source=B bag+}} in=f:\2.txt" }));
}
[Fact]