diff --git a/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs b/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs index 405e207773..64fe3b5b80 100644 --- a/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs +++ b/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs @@ -34,6 +34,7 @@ public enum VisibilityType private string _specialPurpose; private VisibilityType _visibility; private string _name; + private Type _signatureType; /// /// Allows control of command line parsing. @@ -139,5 +140,11 @@ public bool IsRequired { get { return ArgumentType.Required == (_type & ArgumentType.Required); } } + + public Type SignatureType + { + get { return _signatureType; } + set { _signatureType = value; } + } } } \ No newline at end of file diff --git a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs index eb85fcce12..191321213f 100644 --- a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs +++ b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs @@ -249,6 +249,18 @@ public enum SettingsFlags Default = ShortNames | NoSlashes } + /// + /// An IComponentFactory that is used in the command line. + /// + /// This allows components to be created by name, signature type, and a settings string. + /// + public interface ICommandLineComponentFactory : IComponentFactory + { + Type SignatureType { get; } + string Name { get; } + string GetSettingsString(); + } + /// /// Parser for command line arguments. /// @@ -797,7 +809,8 @@ private bool ParseArgumentList(ArgumentInfo info, string[] strs, object destinat ModuleCatalog.ComponentInfo component; if (IsCurlyGroup(value) && value.Length == 2) arg.Field.SetValue(destination, null); - else if (_catalog.Value.TryFindComponentCaseInsensitive(arg.Field.FieldType, value, out component)) + else if (!arg.IsCollection && + _catalog.Value.TryFindComponentCaseInsensitive(arg.Field.FieldType, value, out component)) { var activator = Activator.CreateInstance(component.ArgumentType); if (!IsCurlyGroup(value) && i + 1 < strs.Length && IsCurlyGroup(strs[i + 1])) @@ -810,8 +823,9 @@ private bool ParseArgumentList(ArgumentInfo info, string[] strs, object destinat } else { - Report("Error: Failed to find component with name '{0}' for option '{1}'", value, arg.LongName); - hadError |= true; + hadError |= !arg.SetValue(this, ref values[arg.Index], value, tag, destination); + if (!IsCurlyGroup(value) && i + 1 < strs.Length && IsCurlyGroup(strs[i + 1])) + hadError |= !arg.SetValue(this, ref values[arg.Index], strs[++i], "", destination); } continue; } @@ -1532,6 +1546,8 @@ private sealed class Argument // Used for help and composing settings strings. public readonly object DefaultValue; + private readonly Type _signatureType; + // For custom types. private readonly ArgumentInfo _infoCustom; private readonly ConstructorInfo _ctorCustom; @@ -1559,6 +1575,7 @@ public Argument(int index, string name, string[] nicks, object defaults, Argumen IsDefault = attr is DefaultArgumentAttribute; Contracts.Assert(!IsDefault || Utils.Size(ShortNames) == 0); IsHidden = attr.Hide; + _signatureType = attr.SignatureType; if (field.FieldType.IsArray) { @@ -1664,6 +1681,40 @@ public bool Finish(CmdParser owner, ArgValue val, object destination) Field.SetValue(destination, com); } + else if (IsSingleComponentFactory) + { + bool haveName = false; + string name = null; + string[] settings = null; + for (int i = 0; i < Utils.Size(values);) + { + string str = (string)values[i].Value; + if (str.StartsWith("{")) + { + i++; + continue; + } + if (haveName) + { + owner.Report("Duplicate component kind for argument {0}", LongName); + error = true; + } + name = str; + haveName = true; + values.RemoveAt(i); + } + + if (Utils.Size(values) > 0) + settings = values.Select(x => (string)x.Value).ToArray(); + + Contracts.Check(_signatureType != null, "ComponentFactory Arguments need a SignatureType set."); + var factory = ComponentFactoryFactory.CreateComponentFactory( + ItemType, + _signatureType, + name, + settings); + Field.SetValue(destination, factory); + } else if (IsMultiSubComponent) { // REVIEW: the kind should not be separated from settings: everything related @@ -1711,6 +1762,63 @@ public bool Finish(CmdParser owner, ArgValue val, object destination) Field.SetValue(destination, arr); } } + else if (IsMultiComponentFactory) + { + // REVIEW: the kind should not be separated from settings: everything related + // to one item should go into one value, not multiple values + if (IsTaggedCollection) + { + // Tagged collection of IComponentFactory + var comList = new List>(); + + for (int i = 0; i < Utils.Size(values);) + { + string tag = values[i].Key; + string name = (string)values[i++].Value; + string[] settings = null; + if (i < values.Count && IsCurlyGroup((string)values[i].Value) && string.IsNullOrEmpty(values[i].Key)) + settings = new string[] { (string)values[i++].Value }; + var factory = ComponentFactoryFactory.CreateComponentFactory( + ItemValueType, + _signatureType, + name, + settings); + comList.Add(new KeyValuePair(tag, factory)); + } + + var arr = Array.CreateInstance(ItemType, comList.Count); + for (int i = 0; i < arr.Length; i++) + { + var kvp = Activator.CreateInstance(ItemType, comList[i].Key, comList[i].Value); + arr.SetValue(kvp, i); + } + + Field.SetValue(destination, arr); + } + else + { + // Collection of IComponentFactory + var comList = new List(); + for (int i = 0; i < Utils.Size(values);) + { + string name = (string)values[i++].Value; + string[] settings = null; + if (i < values.Count && IsCurlyGroup((string)values[i].Value)) + settings = new string[] { (string)values[i++].Value }; + var factory = ComponentFactoryFactory.CreateComponentFactory( + ItemValueType, + _signatureType, + name, + settings); + comList.Add(factory); + } + + var arr = Array.CreateInstance(ItemValueType, comList.Count); + for (int i = 0; i < arr.Length; i++) + arr.SetValue(comList[i], i); + Field.SetValue(destination, arr); + } + } else if (IsTaggedCollection) { var res = Array.CreateInstance(ItemType, Utils.Size(values)); @@ -1732,6 +1840,118 @@ public bool Finish(CmdParser owner, ArgValue val, object destination) return error; } + /// + /// A factory class for creating IComponentFactory instances. + /// + private static class ComponentFactoryFactory + { + public static IComponentFactory CreateComponentFactory( + Type factoryType, + Type signatureType, + string name, + string[] settings) + { + Contracts.Check(factoryType != null && + typeof(IComponentFactory).IsAssignableFrom(factoryType) && + factoryType.IsGenericType); + + Type componentFactoryType; + if (factoryType.GenericTypeArguments.Length == 1) + { + componentFactoryType = typeof(ComponentFactory<>); + } + else if (factoryType.GenericTypeArguments.Length == 2) + { + componentFactoryType = typeof(ComponentFactory<,>); + } + else + { + throw Contracts.ExceptNotImpl("ComponentFactoryFactory can only create components with 1 or 2 type args."); + } + + return (IComponentFactory)Activator.CreateInstance( + componentFactoryType.MakeGenericType(factoryType.GenericTypeArguments), + signatureType, + name, + settings); + } + + private abstract class ComponentFactory : ICommandLineComponentFactory + { + public Type SignatureType { get; } + public string Name { get; } + private string[] Settings { get; } + + protected ComponentFactory(Type signatureType, string name, string[] settings) + { + SignatureType = signatureType; + Name = name; + + if (settings == null || (settings.Length == 1 && string.IsNullOrEmpty(settings[0]))) + { + settings = Array.Empty(); + } + Settings = settings; + } + + public string GetSettingsString() + { + return CombineSettings(Settings); + } + + public override string ToString() + { + if (string.IsNullOrEmpty(Name) && Settings.Length == 0) + return "{}"; + + if (Settings.Length == 0) + return Name; + + string str = CombineSettings(Settings); + StringBuilder sb = new StringBuilder(); + CmdQuoter.QuoteValue(str, sb, true); + return Name + sb.ToString(); + } + } + + private class ComponentFactory : ComponentFactory, IComponentFactory + where TComponent : class + { + public ComponentFactory(Type signatureType, string name, string[] settings) + : base(signatureType, name, settings) + { + } + + public TComponent CreateComponent(IHostEnvironment env) + { + return ComponentCatalog.CreateInstance( + env, + SignatureType, + Name, + GetSettingsString()); + } + } + + private class ComponentFactory : ComponentFactory, IComponentFactory + where TComponent : class + { + public ComponentFactory(Type signatureType, string name, string[] settings) + : base(signatureType, name, settings) + { + } + + public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1) + { + return ComponentCatalog.CreateInstance( + env, + SignatureType, + Name, + GetSettingsString(), + argument1); + } + } + } + private bool ReportMissingRequiredArgument(CmdParser owner, ArgValue val) { if (!IsRequired || val != null) @@ -1784,7 +2004,7 @@ public bool SetValue(CmdParser owner, ref ArgValue val, string value, string tag } val.Values.Add(new KeyValuePair(tag, newValue)); } - else if (IsSingleSubComponent) + else if (IsSingleSubComponent || IsComponentFactory) { Contracts.Assert(newValue is string || newValue == null); Contracts.Assert((string)newValue != ""); @@ -1834,7 +2054,7 @@ private bool ParseValue(CmdParser owner, string data, out object value) return false; } - if (IsSubComponentItemType) + if (IsSubComponentItemType || IsComponentFactory) { value = data; return true; @@ -2186,19 +2406,28 @@ private string GetString(IExceptionContext ectx, object value, StringBuilder buf string name; var catalog = ModuleCatalog.CreateInstance(ectx); var type = value.GetType(); - bool success = catalog.TryGetComponentShortName(type, out name); - Contracts.Assert(success); - - var settings = GetSettings(ectx, value, Activator.CreateInstance(type)); - buffer.Clear(); - buffer.Append(name); - if (!string.IsNullOrWhiteSpace(settings)) + bool isModuleComponent = catalog.TryGetComponentShortName(type, out name); + if (isModuleComponent) { - StringBuilder sb = new StringBuilder(); - CmdQuoter.QuoteValue(settings, sb, true); - buffer.Append(sb); + var settings = GetSettings(ectx, value, Activator.CreateInstance(type)); + buffer.Clear(); + buffer.Append(name); + if (!string.IsNullOrWhiteSpace(settings)) + { + StringBuilder sb = new StringBuilder(); + CmdQuoter.QuoteValue(settings, sb, true); + buffer.Append(sb); + } + return buffer.ToString(); + } + else if (value is ICommandLineComponentFactory) + { + return value.ToString(); + } + else + { + throw ectx.Except($"IComponentFactory instances either need to be EntryPointComponents or implement {nameof(ICommandLineComponentFactory)}."); } - return buffer.ToString(); } return value.ToString(); @@ -2344,6 +2573,16 @@ public bool IsMultiSubComponent { get { return IsSubComponentItemType && Field.FieldType.IsArray; } } + public bool IsSingleComponentFactory + { + get { return IsComponentFactory && !Field.FieldType.IsArray; } + } + + public bool IsMultiComponentFactory + { + get { return IsComponentFactory && Field.FieldType.IsArray; } + } + public bool IsCustomItemType { get { return _infoCustom != null; } } diff --git a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs index a92c20a4e6..ddbbd2a500 100644 --- a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs +++ b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs @@ -832,10 +832,15 @@ public static LoadableClassInfo[] FindLoadableClasses() public static LoadableClassInfo GetLoadableClassInfo(string loadName) { - Contracts.CheckParam(typeof(TSig).BaseType == typeof(MulticastDelegate), nameof(TSig), "TSig must be a delegate type"); + return GetLoadableClassInfo(loadName, typeof(TSig)); + } + + public static LoadableClassInfo GetLoadableClassInfo(string loadName, Type signatureType) + { + Contracts.CheckParam(signatureType.BaseType == typeof(MulticastDelegate), nameof(signatureType), "signatureType must be a delegate type"); Contracts.CheckValueOrNull(loadName); loadName = (loadName ?? "").ToLowerInvariant().Trim(); - return FindClassCore(new LoadableClassInfo.Key(loadName, typeof(TSig))); + return FindClassCore(new LoadableClassInfo.Key(loadName, signatureType)); } public static LoadableClassInfo GetLoadableClassInfo(SubComponent sub) @@ -886,6 +891,18 @@ public static TRes CreateInstance(this SubComponent comp throw Contracts.Except("Unknown loadable class: {0}", comp.Kind).MarkSensitive(MessageSensitivity.None); } + /// + /// Create an instance of the indicated component with the given extra parameters. + /// + public static TRes CreateInstance(IHostEnvironment env, Type signatureType, string name, string options, params object[] extra) + where TRes : class + { + TRes result; + if (TryCreateInstance(env, signatureType, out result, name, options, extra)) + return result; + throw Contracts.Except("Unknown loadable class: {0}", name).MarkSensitive(MessageSensitivity.None); + } + /// /// Try to create an instance of the indicated component with the given extra parameters. If there is no /// such component in the catalog, returns false. Any other error results in an exception. @@ -913,13 +930,19 @@ public static bool TryCreateInstance(IHostEnvironment env, out TRes /// public static bool TryCreateInstance(IHostEnvironment env, out TRes result, string name, string options, params object[] extra) where TRes : class + { + return TryCreateInstance(env, typeof(TSig), out result, name, options, extra); + } + + private static bool TryCreateInstance(IHostEnvironment env, Type signatureType, out TRes result, string name, string options, params object[] extra) + where TRes : class { Contracts.CheckValue(env, nameof(env)); - env.Check(typeof(TSig).BaseType == typeof(MulticastDelegate)); + env.Check(signatureType.BaseType == typeof(MulticastDelegate)); env.CheckValueOrNull(name); string nameLower = (name ?? "").ToLowerInvariant().Trim(); - LoadableClassInfo info = FindClassCore(new LoadableClassInfo.Key(nameLower, typeof(TSig))); + LoadableClassInfo info = FindClassCore(new LoadableClassInfo.Key(nameLower, signatureType)); if (info == null) { result = null; diff --git a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs index 9334f0f225..d69a9d0b93 100644 --- a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs +++ b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs @@ -37,6 +37,26 @@ public interface IComponentFactory : IComponentFactory TComponent CreateComponent(IHostEnvironment env, TArg1 argument1); } + /// + /// A class for creating a component when we take one extra parameter + /// (and an ) that simply wraps a delegate which + /// creates the component. + /// + public class SimpleComponentFactory : IComponentFactory + { + private Func _factory; + + public SimpleComponentFactory(Func factory) + { + _factory = factory; + } + + public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1) + { + return _factory(env, argument1); + } + } + /// /// An interface for creating a component when we take two extra parameters (and an ). /// diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index 26ec32d3fe..affd949064 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -11,6 +11,7 @@ using Microsoft.ML.Runtime.Command; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Internal.Utilities; @@ -69,8 +70,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase [Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether we should cache input training data", ShortName = "cache")] public bool? CacheData; - [Argument(ArgumentType.Multiple, HelpText = "Transforms to apply prior to splitting the data into folds", ShortName = "prexf")] - public KeyValuePair>[] PreTransform; + [Argument(ArgumentType.Multiple, HelpText = "Transforms to apply prior to splitting the data into folds", ShortName = "prexf", SignatureType = typeof(SignatureDataTransform))] + public KeyValuePair>[] PreTransform; [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The validation data file", ShortName = "valid")] public string ValidationFile; @@ -159,16 +160,18 @@ private void RunCore(IChannel ch, string cmd) string name = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name); if (name == null) { - var args = new GenerateNumberTransform.Arguments(); - args.Column = new[] { new GenerateNumberTransform.Column() { Name = DefaultColumnNames.Name }, }; - args.UseCounter = true; - var options = CmdParser.GetSettings(ch, args, new GenerateNumberTransform.Arguments()); preXf = preXf.Concat( new[] { - new KeyValuePair>( - "", new SubComponent( - GenerateNumberTransform.LoadName, options)) + new KeyValuePair>( + "", new SimpleComponentFactory( + (env, input) => + { + var args = new GenerateNumberTransform.Arguments(); + args.Column = new[] { new GenerateNumberTransform.Column() { Name = DefaultColumnNames.Name }, }; + args.UseCounter = true; + return new GenerateNumberTransform(env, args, input); + })) }).ToArray(); } } @@ -263,7 +266,7 @@ private RoleMappedData ApplyAllTransformsToData(IHostEnvironment env, IChannel c private RoleMappedData CreateRoleMappedData(IHostEnvironment env, IChannel ch, IDataView data, ITrainer trainer) { foreach (var kvp in Args.Transform) - data = kvp.Value.CreateInstance(env, data); + data = kvp.Value.CreateComponent(env, data); var schema = data.Schema; string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.LabelColumn), Args.LabelColumn, DefaultColumnNames.Label); diff --git a/src/Microsoft.ML.Data/Commands/DataCommand.cs b/src/Microsoft.ML.Data/Commands/DataCommand.cs index 2a62d78901..8f489e270f 100644 --- a/src/Microsoft.ML.Data/Commands/DataCommand.cs +++ b/src/Microsoft.ML.Data/Commands/DataCommand.cs @@ -8,6 +8,8 @@ using System.Linq; using Microsoft.ML.Runtime.Command; using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.Runtime.Data.IO; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; @@ -20,8 +22,8 @@ public static class DataCommand { public abstract class ArgumentsBase { - [Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "The data loader", ShortName = "loader", SortOrder = 1, NullName = "")] - public SubComponent Loader; + [Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "The data loader", ShortName = "loader", SortOrder = 1, NullName = "", SignatureType = typeof(SignatureDataLoader))] + public IComponentFactory Loader; [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The data file", ShortName = "data", SortOrder = 0)] public string DataFile; @@ -51,8 +53,8 @@ public abstract class ArgumentsBase HelpText = "Desired degree of parallelism in the data pipeline", ShortName = "n")] public int? Parallel; - [Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Transform", ShortName = "xf")] - public KeyValuePair>[] Transform; + [Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Transform", ShortName = "xf", SignatureType = typeof(SignatureDataTransform))] + public KeyValuePair>[] Transform; } public abstract class ImplBase : ICommand @@ -125,6 +127,17 @@ protected void SendTelemetryComponent(IPipe pipe, SubComponent pipe.Send(TelemetryMessage.CreateTrainer(sub.Kind, sub.SubComponentSettings)); } + protected void SendTelemetryComponent(IPipe pipe, IComponentFactory factory) + { + Host.AssertValue(pipe); + Host.AssertValueOrNull(factory); + + if (factory is ICommandLineComponentFactory commandLineFactory) + pipe.Send(TelemetryMessage.CreateTrainer(commandLineFactory.Name, commandLineFactory.GetSettingsString())); + else + pipe.Send(TelemetryMessage.CreateTrainer("Unknown", "Non-ICommandLineComponentFactory object")); + } + protected virtual void SendTelemetryCore(IPipe pipe) { Contracts.AssertValue(pipe); @@ -212,9 +225,9 @@ protected void SaveLoader(IDataLoader loader, string path) LoaderUtils.SaveLoader(loader, file); } - protected IDataLoader CreateAndSaveLoader(string defaultLoader = "TextLoader") + protected IDataLoader CreateAndSaveLoader(Func defaultLoaderFactory = null) { - var loader = CreateLoader(defaultLoader); + var loader = CreateLoader(defaultLoaderFactory); if (!string.IsNullOrWhiteSpace(Args.OutputModelFile)) { using (var file = Host.CreateOutputFile(Args.OutputModelFile)) @@ -268,12 +281,12 @@ protected void LoadModelObjects( } // Next create the loader. - var sub = Args.Loader; + var loaderFactory = Args.Loader; IDataLoader trainPipe = null; - if (sub.IsGood()) + if (loaderFactory != null) { // The loader is overridden from the command line. - pipe = sub.CreateInstance(Host, new MultiFileSource(Args.DataFile)); + pipe = loaderFactory.CreateComponent(Host, new MultiFileSource(Args.DataFile)); if (Args.LoadTransforms == true) { Host.CheckUserArg(!string.IsNullOrWhiteSpace(Args.InputModelFile), nameof(Args.InputModelFile)); @@ -316,9 +329,9 @@ protected void LoadModelObjects( } } - protected IDataLoader CreateLoader(string defaultLoader = "TextLoader") + protected IDataLoader CreateLoader(Func defaultLoaderFactory = null) { - var loader = CreateRawLoader(defaultLoader); + var loader = CreateRawLoader(defaultLoaderFactory); loader = CreateTransformChain(loader); return loader; } @@ -328,13 +341,15 @@ private IDataLoader CreateTransformChain(IDataLoader loader) return CompositeDataLoader.Create(Host, loader, Args.Transform); } - protected IDataLoader CreateRawLoader(string defaultLoader = "TextLoader", string dataFile = null) + protected IDataLoader CreateRawLoader( + Func defaultLoaderFactory = null, + string dataFile = null) { if (string.IsNullOrWhiteSpace(dataFile)) dataFile = Args.DataFile; IDataLoader loader; - if (!string.IsNullOrWhiteSpace(Args.InputModelFile) && !Args.Loader.IsGood()) + if (!string.IsNullOrWhiteSpace(Args.InputModelFile) && Args.Loader == null) { // Load the loader from the data model. using (var file = Host.OpenInputFile(Args.InputModelFile)) @@ -345,8 +360,9 @@ protected IDataLoader CreateRawLoader(string defaultLoader = "TextLoader", strin else { // Either there is no input model file, or there is, but the loader is overridden. - var sub = Args.Loader; - if (!sub.IsGood()) + IMultiStreamSource fileSource = new MultiFileSource(dataFile); + var loaderFactory = Args.Loader; + if (loaderFactory == null) { var ext = Path.GetExtension(dataFile); var isText = @@ -354,12 +370,17 @@ protected IDataLoader CreateRawLoader(string defaultLoader = "TextLoader", strin string.Equals(ext, ".tlc", StringComparison.OrdinalIgnoreCase); var isBinary = string.Equals(ext, ".idv", StringComparison.OrdinalIgnoreCase); var isTranspose = string.Equals(ext, ".tdv", StringComparison.OrdinalIgnoreCase); - sub = - new SubComponent( - isText ? "TextLoader" : isBinary ? "BinaryLoader" : isTranspose ? "TransposeLoader" : defaultLoader); - } - loader = sub.CreateInstance(Host, new MultiFileSource(dataFile)); + return isText ? new TextLoader(Host, new TextLoader.Arguments(), fileSource) : + isBinary ? new BinaryLoader(Host, new BinaryLoader.Arguments(), fileSource) : + isTranspose ? new TransposeLoader(Host, new TransposeLoader.Arguments(), fileSource) : + defaultLoaderFactory != null ? defaultLoaderFactory(Host, fileSource) : + new TextLoader(Host, new TextLoader.Arguments(), fileSource); + } + else + { + loader = loaderFactory.CreateComponent(Host, fileSource); + } if (Args.LoadTransforms == true) { diff --git a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs index 77bdf0e32f..c122c65ffa 100644 --- a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs +++ b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs @@ -217,7 +217,8 @@ private void RunCore(IChannel ch) Host.AssertValue(ch); ch.Trace("Creating loader"); - IDataView view = CreateAndSaveLoader(IO.BinaryLoader.LoadName); + IDataView view = CreateAndSaveLoader( + (env, source) => new IO.BinaryLoader(env, new IO.BinaryLoader.Arguments(), source)); ch.Trace("Binding columns"); ISchema schema = view.Schema; diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs index 607bf119d7..f69c35231d 100644 --- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs +++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs @@ -11,6 +11,7 @@ using Microsoft.ML.Runtime.Command; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; @@ -62,8 +63,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to include hidden columns", ShortName = "keep")] public bool KeepHidden; - [Argument(ArgumentType.Multiple, HelpText = "Post processing transform", ShortName = "pxf")] - public KeyValuePair>[] PostTransform; + [Argument(ArgumentType.Multiple, HelpText = "Post processing transform", ShortName = "pxf", SignatureType = typeof(SignatureDataTransform))] + public KeyValuePair>[] PostTransform; [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to output all columns or just scores", ShortName = "all")] public bool? OutputAllColumns; diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs index a2ab3a7b16..d83af5d824 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs @@ -14,6 +14,7 @@ using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.EntryPoints; [assembly: LoadableClass(typeof(IDataLoader), typeof(CompositeDataLoader), typeof(CompositeDataLoader.Arguments), typeof(SignatureDataLoader), "Composite Data Loader", "CompositeDataLoader", "Composite", "PipeData", "Pipe", "PipeDataLoader")] @@ -34,11 +35,11 @@ public sealed class CompositeDataLoader : IDataLoader, ITransposeDataView { public sealed class Arguments { - [Argument(ArgumentType.Multiple, HelpText = "The data loader", ShortName = "loader")] - public SubComponent Loader; + [Argument(ArgumentType.Multiple, HelpText = "The data loader", ShortName = "loader", SignatureType = typeof(SignatureDataLoader))] + public IComponentFactory Loader; - [Argument(ArgumentType.Multiple, HelpText = "Transform", ShortName = "xf")] - public KeyValuePair>[] Transform; + [Argument(ArgumentType.Multiple, HelpText = "Transform", ShortName = "xf", SignatureType = typeof(SignatureDataTransform))] + public KeyValuePair>[] Transform; } private struct TransformEx @@ -98,10 +99,10 @@ public static IDataLoader Create(IHostEnvironment env, Arguments args, IMultiStr var h = env.Register(RegistrationName); h.CheckValue(args, nameof(args)); - h.CheckUserArg(args.Loader.IsGood(), nameof(args.Loader)); + h.CheckValue(args.Loader, nameof(args.Loader)); h.CheckValue(files, nameof(files)); - var loader = args.Loader.CreateInstance(h, files); + var loader = args.Loader.CreateComponent(h, files); return CreateCore(h, loader, args.Transform); } @@ -111,7 +112,7 @@ public static IDataLoader Create(IHostEnvironment env, Arguments args, IMultiStr /// If there are no transforms, the is returned. /// public static IDataLoader Create(IHostEnvironment env, IDataLoader srcLoader, - params KeyValuePair>[] transformArgs) + params KeyValuePair>[] transformArgs) { Contracts.CheckValue(env, nameof(env)); var h = env.Register(RegistrationName); @@ -122,7 +123,7 @@ public static IDataLoader Create(IHostEnvironment env, IDataLoader srcLoader, } private static IDataLoader CreateCore(IHost host, IDataLoader srcLoader, - KeyValuePair>[] transformArgs) + KeyValuePair>[] transformArgs) { Contracts.AssertValue(host, "host"); host.AssertValue(srcLoader, "srcLoader"); @@ -131,8 +132,15 @@ private static IDataLoader CreateCore(IHost host, IDataLoader srcLoader, if (Utils.Size(transformArgs) == 0) return srcLoader; + string GetTagData(IComponentFactory factory) + { + // When coming from the command line, preserve the string arguments. + // For other factories, we aren't able to get the string. + return (factory as ICommandLineComponentFactory)?.ToString(); + } + var tagData = transformArgs - .Select(x => new KeyValuePair(x.Key, x.Value.ToString())) + .Select(x => new KeyValuePair(x.Key, GetTagData(x.Value))) .ToArray(); // Warn if tags coincide with ones already present in the loader. @@ -152,7 +160,7 @@ private static IDataLoader CreateCore(IHost host, IDataLoader srcLoader, } return ApplyTransformsCore(host, srcLoader, tagData, - (prov, index, data) => transformArgs[index].Value.CreateInstance(prov, data)); + (env, index, data) => transformArgs[index].Value.CreateComponent(env, data)); } /// diff --git a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs index 6eaf48e995..6365e27a54 100644 --- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs @@ -117,8 +117,8 @@ public abstract class ArgumentsBase : TransformInputBase [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "Data file containing the terms", ShortName = "data", SortOrder = 110, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)] public string DataFile; - [Argument(ArgumentType.Multiple, HelpText = "Data loader", NullName = "", SortOrder = 111, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)] - public SubComponent Loader; + [Argument(ArgumentType.Multiple, HelpText = "Data loader", NullName = "", SortOrder = 111, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureDataLoader))] + public IComponentFactory Loader; [Argument(ArgumentType.AtMostOnce, HelpText = "Name of the text column containing the terms", ShortName = "termCol", SortOrder = 112, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)] public string TermsColumn; @@ -309,12 +309,19 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, Argu string file = args.DataFile; // First column using the file. string src = args.TermsColumn; - var sub = args.Loader; + IMultiStreamSource fileSource = new MultiFileSource(file); + + var loaderFactory = args.Loader; // If the user manually specifies a loader, or this is already a pre-processed binary // file, then we assume the user knows what they're doing and do not attempt to convert // to the desired type ourselves. bool autoConvert = false; - if (!sub.IsGood()) + IDataLoader loader; + if (loaderFactory != null) + { + loader = loaderFactory.CreateComponent(env, fileSource); + } + else { // Determine the default loader from the extension. var ext = Path.GetExtension(file); @@ -326,11 +333,11 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, Argu ch.CheckUserArg(!string.IsNullOrWhiteSpace(src), nameof(args.TermsColumn), "Must be specified"); if (isBinary) - sub = new SubComponent("BinaryLoader"); + loader = new BinaryLoader(env, new BinaryLoader.Arguments(), fileSource); else { ch.Assert(isTranspose); - sub = new SubComponent("TransposeLoader"); + loader = new TransposeLoader(env, new TransposeLoader.Arguments(), fileSource); } } else @@ -341,7 +348,21 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, Argu "{0} should not be specified when default loader is TextLoader. Ignoring {0}={1}", nameof(Arguments.TermsColumn), src); } - sub = new SubComponent("TextLoader", "sep=tab col=Term:TX:0"); + loader = new TextLoader(env, + new TextLoader.Arguments() + { + Separator = "tab", + Column = new[] + { + new TextLoader.Column() + { + Name ="Term", + Type = DataKind.TX, + Source = new[] { new TextLoader.Range() { Min = 0 } } + } + } + }, + fileSource); src = "Term"; autoConvert = true; } @@ -349,8 +370,6 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, Argu ch.AssertNonEmpty(src); int colSrc; - var loader = sub.CreateInstance(env, new MultiFileSource(file)); - if (!loader.Schema.TryGetColumnIndex(src, out colSrc)) throw ch.ExceptUserArg(nameof(args.TermsColumn), "Unknown column '{0}'", src); var typeSrc = loader.Schema.GetColumnType(colSrc); @@ -395,7 +414,7 @@ private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] info ch.AssertValue(trainingData); if ((args.Term != null || !string.IsNullOrEmpty(args.Terms)) && - (!string.IsNullOrWhiteSpace(args.DataFile) || args.Loader.IsGood() || + (!string.IsNullOrWhiteSpace(args.DataFile) || args.Loader != null || !string.IsNullOrWhiteSpace(args.TermsColumn))) { ch.Warning("Explicit term list specified. Data file arguments will be ignored"); diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs index 52cd025370..23a81f78ee 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs @@ -10,6 +10,7 @@ using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Training; +using Microsoft.ML.Runtime.EntryPoints; namespace Microsoft.ML.Runtime.Learners { @@ -21,13 +22,12 @@ public abstract class MetaMulticlassTrainer : TrainerBase { public abstract class ArgumentsBase { - [Argument(ArgumentType.Multiple, HelpText = "Base predictor", ShortName = "p", SortOrder = 1)] + [Argument(ArgumentType.Multiple, HelpText = "Base predictor", ShortName = "p", SortOrder = 1, SignatureType = typeof(SignatureBinaryClassifierTrainer))] [TGUI(Label = "Predictor Type", Description = "Type of underlying binary predictor")] - public SubComponent PredictorType = - new SubComponent(LinearSvm.LoadNameValue); + public IComponentFactory PredictorType; - [Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "")] - public SubComponent Calibrator = new SubComponent("PlattCalibration"); + [Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "", SignatureType = typeof(SignatureCalibrator))] + public IComponentFactory Calibrator = new PlattCalibratorTrainerFactory(); [Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", ShortName = "numcali")] public int MaxCalibrationExamples = 1000000000; @@ -47,14 +47,20 @@ internal MetaMulticlassTrainer(IHostEnvironment env, TArgs args, string name) { Host.CheckValue(args, nameof(args)); Args = args; - Host.CheckUserArg(Args.PredictorType.IsGood(), nameof(Args.PredictorType)); // Create the first trainer so errors in the args surface early. - _trainer = Args.PredictorType.CreateInstance(Host); + _trainer = CreateTrainer(); // Regarding caching, no matter what the internal predictor, we're performing many passes // simply by virtue of this being a meta-trainer, so we will still cache. Info = new TrainerInfo(normalization: _trainer.Info.NeedNormalization); } + private TScalarTrainer CreateTrainer() + { + return Args.PredictorType != null ? + Args.PredictorType.CreateComponent(Host) : + new LinearSvm(Host, new LinearSvm.Arguments()); + } + protected IDataView MapLabelsCore(ColumnType type, RefPredicate equalsTarget, RoleMappedData data, string dstName) { Host.AssertValue(type); @@ -84,7 +90,7 @@ protected TScalarTrainer GetTrainer() { // We may have instantiated the first trainer to use already, from the constructor. // If so capture it and set the retained trainer to null; otherwise create a new one. - var train = _trainer ?? Args.PredictorType.CreateInstance(Host); + var train = _trainer ?? CreateTrainer(); _trainer = null; return train; } diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs index c123411edd..24359feca3 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs @@ -89,10 +89,10 @@ private TScalarPredictor TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappe if (Args.UseProbabilities) { ICalibratorTrainer calibrator; - if (!Args.Calibrator.IsGood()) + if (Args.Calibrator == null) calibrator = null; else - calibrator = Args.Calibrator.CreateInstance(Host); + calibrator = Args.Calibrator.CreateComponent(Host); var res = CalibratorUtils.TrainCalibratorIfNeeded(Host, ch, calibrator, Args.MaxCalibrationExamples, trainer, predictor, td); predictor = res as TScalarPredictor; diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs index 193c8f0290..6434038384 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs @@ -104,10 +104,10 @@ private TDistPredictor TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedD var predictor = trainer.Train(td); ICalibratorTrainer calibrator; - if (!Args.Calibrator.IsGood()) + if (Args.Calibrator == null) calibrator = null; else - calibrator = Args.Calibrator.CreateInstance(Host); + calibrator = Args.Calibrator.CreateComponent(Host); var res = CalibratorUtils.TrainCalibratorIfNeeded(Host, ch, calibrator, Args.MaxCalibrationExamples, trainer, predictor, td); var dist = res as TDistPredictor; diff --git a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs index fb56fa2e0a..9d6f5564cf 100644 --- a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs @@ -451,8 +451,8 @@ public sealed class TermLoaderArguments [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "Data file containing the terms", ShortName = "data", SortOrder = 2, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)] public string DataFile; - [Argument(ArgumentType.Multiple, HelpText = "Data loader", NullName = "", SortOrder = 3, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)] - public SubComponent Loader; + [Argument(ArgumentType.Multiple, HelpText = "Data loader", NullName = "", SortOrder = 3, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureDataLoader))] + public IComponentFactory Loader; [Argument(ArgumentType.AtMostOnce, HelpText = "Name of the text column containing the terms", ShortName = "termCol", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)] public string TermsColumn; diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index 29da956c62..b53062c1a8 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -8,6 +8,7 @@ using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; using Xunit; @@ -215,28 +216,35 @@ protected void VerifyArgParsing(string[] strs) VerifyCustArgs(kvp.Value); } - protected void VerifyCustArgs(SubComponent sub) + protected void VerifyCustArgs(IComponentFactory factory) where TRes : class { - var str = CmdParser.CombineSettings(sub.Settings); - var info = ComponentCatalog.GetLoadableClassInfo(sub.Kind); - Assert.NotNull(info); - var def = info.CreateArguments(); + if (factory is ICommandLineComponentFactory commandLineFactory) + { + var str = commandLineFactory.GetSettingsString(); + var info = ComponentCatalog.GetLoadableClassInfo(commandLineFactory.Name, commandLineFactory.SignatureType); + Assert.NotNull(info); + var def = info.CreateArguments(); - var a1 = info.CreateArguments(); - CmdParser.ParseArguments(Env, str, a1); + var a1 = info.CreateArguments(); + CmdParser.ParseArguments(Env, str, a1); - // Get both the expanded and custom forms. - string exp1 = CmdParser.GetSettings(Env, a1, def, SettingsFlags.Default | SettingsFlags.NoUnparse); - string cust = CmdParser.GetSettings(Env, a1, def); + // Get both the expanded and custom forms. + string exp1 = CmdParser.GetSettings(Env, a1, def, SettingsFlags.Default | SettingsFlags.NoUnparse); + string cust = CmdParser.GetSettings(Env, a1, def); - // Map cust back to an object, then get its full form. - var a2 = info.CreateArguments(); - CmdParser.ParseArguments(Env, cust, a2); - string exp2 = CmdParser.GetSettings(Env, a2, def, SettingsFlags.Default | SettingsFlags.NoUnparse); + // Map cust back to an object, then get its full form. + var a2 = info.CreateArguments(); + CmdParser.ParseArguments(Env, cust, a2); + string exp2 = CmdParser.GetSettings(Env, a2, def, SettingsFlags.Default | SettingsFlags.NoUnparse); - if (exp1 != exp2) - Fail("Custom unparse failed on '{0}' starting with '{1}': '{2}' vs '{3}'", sub.Kind, str, exp1, exp2); + if (exp1 != exp2) + Fail("Custom unparse failed on '{0}' starting with '{1}': '{2}' vs '{3}'", commandLineFactory.Name, str, exp1, exp2); + } + else + { + Fail($"TestDataPipeBase was called with a non command line loader or transform '{factory}'"); + } } protected bool SaveLoadText(IDataView view, IHostEnvironment env,