Skip to content

Initial replacement of SubComponent with IComponentFactory #622

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 8, 2018
7 changes: 7 additions & 0 deletions src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public enum VisibilityType
private string _specialPurpose;
private VisibilityType _visibility;
private string _name;
private Type _signatureType;

/// <summary>
/// Allows control of command line parsing.
Expand Down Expand Up @@ -139,5 +140,11 @@ public bool IsRequired
{
get { return ArgumentType.Required == (_type & ArgumentType.Required); }
}

public Type SignatureType
Copy link
Contributor

@Zruty0 Zruty0 Aug 6, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SignatureType [](start = 20, length = 13)

Does it have to be here? Let me see why it would be needed #Closed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooh, I see now. You need to tell the arg parser what's the signature type to invoke the ctor properly.
Hmm, this is unfortunate. Are you sure there's no better way?


In reply to: 208054268 [](ancestors = 208054268)

Copy link
Member Author

@eerhardt eerhardt Aug 6, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't found one yet...

We could have some sort of mapping from IComponentFactory<...> to SignatureType. But that would seem fragile to me.

My other thinking (that seems to be aligned with @TomFinley's above) is that we should be moving away from SignatureTypes. This at least puts the Signature type into the attribute, and out of the API directly. Hopefully, eventually we can remove signature types all together. #Pending

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, the fact that it's prominent in the attribute makes the whole approach less desirable. But I think it's an OK intermediate step


In reply to: 208058930 [](ancestors = 208058930)

{
get { return _signatureType; }
set { _signatureType = value; }
}
}
}
271 changes: 255 additions & 16 deletions src/Microsoft.ML.Core/CommandLine/CmdParser.cs

Large diffs are not rendered by default.

31 changes: 27 additions & 4 deletions src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -832,10 +832,15 @@ public static LoadableClassInfo[] FindLoadableClasses<TArgs, TSig>()

public static LoadableClassInfo GetLoadableClassInfo<TSig>(string loadName)
{
Contracts.CheckParam(typeof(TSig).BaseType == typeof(MulticastDelegate), nameof(TSig), "TSig must be a delegate type");
return GetLoadableClassInfo(loadName, typeof(TSig));
}

public static LoadableClassInfo GetLoadableClassInfo(string loadName, Type signatureType)
{
Contracts.CheckParam(signatureType.BaseType == typeof(MulticastDelegate), nameof(signatureType), "signatureType must be a delegate type");
Contracts.CheckValueOrNull(loadName);
loadName = (loadName ?? "").ToLowerInvariant().Trim();
return FindClassCore(new LoadableClassInfo.Key(loadName, typeof(TSig)));
return FindClassCore(new LoadableClassInfo.Key(loadName, signatureType));
}

public static LoadableClassInfo GetLoadableClassInfo<TRes, TSig>(SubComponent<TRes, TSig> sub)
Expand Down Expand Up @@ -886,6 +891,18 @@ public static TRes CreateInstance<TRes, TSig>(this SubComponent<TRes, TSig> comp
throw Contracts.Except("Unknown loadable class: {0}", comp.Kind).MarkSensitive(MessageSensitivity.None);
}

/// <summary>
/// Create an instance of the indicated component with the given extra parameters.
/// </summary>
public static TRes CreateInstance<TRes>(IHostEnvironment env, Type signatureType, string name, string options, params object[] extra)
where TRes : class
{
TRes result;
if (TryCreateInstance(env, signatureType, out result, name, options, extra))
return result;
throw Contracts.Except("Unknown loadable class: {0}", name).MarkSensitive(MessageSensitivity.None);
}

/// <summary>
/// Try to create an instance of the indicated component with the given extra parameters. If there is no
/// such component in the catalog, returns false. Any other error results in an exception.
Expand Down Expand Up @@ -913,13 +930,19 @@ public static bool TryCreateInstance<TRes, TSig>(IHostEnvironment env, out TRes
/// </summary>
public static bool TryCreateInstance<TRes, TSig>(IHostEnvironment env, out TRes result, string name, string options, params object[] extra)
where TRes : class
{
return TryCreateInstance<TRes>(env, typeof(TSig), out result, name, options, extra);
}

private static bool TryCreateInstance<TRes>(IHostEnvironment env, Type signatureType, out TRes result, string name, string options, params object[] extra)
where TRes : class
{
Contracts.CheckValue(env, nameof(env));
env.Check(typeof(TSig).BaseType == typeof(MulticastDelegate));
env.Check(signatureType.BaseType == typeof(MulticastDelegate));
env.CheckValueOrNull(name);

string nameLower = (name ?? "").ToLowerInvariant().Trim();
LoadableClassInfo info = FindClassCore(new LoadableClassInfo.Key(nameLower, typeof(TSig)));
LoadableClassInfo info = FindClassCore(new LoadableClassInfo.Key(nameLower, signatureType));
if (info == null)
{
result = null;
Expand Down
20 changes: 20 additions & 0 deletions src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,26 @@ public interface IComponentFactory<in TArg1, out TComponent> : IComponentFactory
TComponent CreateComponent(IHostEnvironment env, TArg1 argument1);
}

/// <summary>
/// A class for creating a component when we take one extra parameter
/// (and an <see cref="IHostEnvironment"/>) that simply wraps a delegate which
/// creates the component.
/// </summary>
public class SimpleComponentFactory<TArg1, TComponent> : IComponentFactory<TArg1, TComponent>
{
private Func<IHostEnvironment, TArg1, TComponent> _factory;

public SimpleComponentFactory(Func<IHostEnvironment, TArg1, TComponent> factory)
{
_factory = factory;
}

public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1)
{
return _factory(env, argument1);
}
}

/// <summary>
/// An interface for creating a component when we take two extra parameters (and an <see cref="IHostEnvironment"/>).
/// </summary>
Expand Down
23 changes: 13 additions & 10 deletions src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.ML.Runtime.Command;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Calibration;
using Microsoft.ML.Runtime.Internal.Utilities;

Expand Down Expand Up @@ -69,8 +70,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether we should cache input training data", ShortName = "cache")]
public bool? CacheData;

[Argument(ArgumentType.Multiple, HelpText = "Transforms to apply prior to splitting the data into folds", ShortName = "prexf")]
public KeyValuePair<string, SubComponent<IDataTransform, SignatureDataTransform>>[] PreTransform;
[Argument(ArgumentType.Multiple, HelpText = "Transforms to apply prior to splitting the data into folds", ShortName = "prexf", SignatureType = typeof(SignatureDataTransform))]
public KeyValuePair<string, IComponentFactory<IDataView, IDataTransform>>[] PreTransform;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not necessarily how I'd prefer to do these things, for two reasons. Firstly, while IComponentFactory is an interface, usually when appearing in these arguments objects we actually are indicating a specific inheriting interface appropriate for this task... e.g., I would expect to see something like IDataTransformFactory.

Also: our goal would be to also get rid of this little "signature" concept.

Signature is used for two major purposes, which I list here and describe how I think they could be better accomplished.

  1. It constrains the type based on the generic type (in the case where signatures are used. e.g., we might imagine an ITrainerFactory<IBinaryPredictor> in the pre-estimator world). This would I think suffice.

  2. It indicates what parameters will be necessary to pass in to instantiate the object in the Create method or the constructor. However, in the case where the signature is indicating "globally appropriate" input (in this case, IDataTransform is always instantiated with the input data), it would be better for the subinterface to just have a method directly, without using all this reflection based stuff at all.

I would have an actual subinterface type (that is, IDataTransformFactory or something similarly named. Then this interface would itself have a method to create the transform, and that method would have at least two parameters (in the current world): the env and the input data.

Copy link
Member Author

@eerhardt eerhardt Aug 1, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also: our goal would be to also get rid of this little "signature" concept.

I 100% agree. But it is too baked into the system right now to remove it in a single commit. Moving to IComponentFactory will lessen the dependence on the "signature" concept, but it will have to stay for a while until we can completely revamp the ComponentCatalog system.

usually when appearing in these arguments objects we actually are indicating a specific inheriting interface appropriate for this task... e.g., I would expect to see something like IDataTransformFactory.

I would have an actual subinterface type (that is, IDataTransformFactory or something similarly named. Then this interface would itself have a method to create the transform, and that method would have at least two parameters (in the current world): the env and the input data.

I'd need to think about this a bit more. My initial reaction is that this would cause the following disadvantages:

  1. We would have a lot of different IComponentFactory interfaces instead of the 3 we have so far. As we added new types of components, we would constantly be needing to create new interfaces.
    • I think this is analogous with delegates in .NET 1.0 and 2.0. If you needed a delegate that took in an int and returned a bool, you needed to make a brand new delegate type. But in .NET 3.5, we added Action<> and Func<> delegates. And now you barely ever need to define custom delegate types. Instead you can just use Func<int, bool>.
  2. We couldn't have the small set of SimpleComponentFactory classes that just takes a delegate to create the component. We would have to define a new class for each subinterface.
  3. Similarly as (2) above, the CmdParser class would have to (1) know about (2) define a new class, and (3) create different objects for each of the subinterfaces the API defines.
    • (Obviously I know the command line isn't a high priority, or selling reason, but we still need to support it and other tooling scenarios) #Pending

Copy link
Member Author

@eerhardt eerhardt Aug 1, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/cc @Zruty0 - any thoughts here? According to history you were the original author #Resolved

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @[email protected] 's reasoning. I think we could check it in as is.


In reply to: 207046122 [](ancestors = 207046122)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, I am OK with this solution.


In reply to: 207057777 [](ancestors = 207057777)


[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The validation data file", ShortName = "valid")]
public string ValidationFile;
Expand Down Expand Up @@ -159,16 +160,18 @@ private void RunCore(IChannel ch, string cmd)
string name = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name);
if (name == null)
{
var args = new GenerateNumberTransform.Arguments();
args.Column = new[] { new GenerateNumberTransform.Column() { Name = DefaultColumnNames.Name }, };
args.UseCounter = true;
var options = CmdParser.GetSettings(ch, args, new GenerateNumberTransform.Arguments());
preXf = preXf.Concat(
new[]
{
new KeyValuePair<string, SubComponent<IDataTransform, SignatureDataTransform>>(
"", new SubComponent<IDataTransform, SignatureDataTransform>(
GenerateNumberTransform.LoadName, options))
new KeyValuePair<string, IComponentFactory<IDataView, IDataTransform>>(
"", new SimpleComponentFactory<IDataView, IDataTransform>(
(env, input) =>
{
var args = new GenerateNumberTransform.Arguments();
args.Column = new[] { new GenerateNumberTransform.Column() { Name = DefaultColumnNames.Name }, };
args.UseCounter = true;
return new GenerateNumberTransform(env, args, input);
}))
}).ToArray();
}
}
Expand Down Expand Up @@ -263,7 +266,7 @@ private RoleMappedData ApplyAllTransformsToData(IHostEnvironment env, IChannel c
private RoleMappedData CreateRoleMappedData(IHostEnvironment env, IChannel ch, IDataView data, ITrainer trainer)
{
foreach (var kvp in Args.Transform)
data = kvp.Value.CreateInstance(env, data);
data = kvp.Value.CreateComponent(env, data);

var schema = data.Schema;
string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.LabelColumn), Args.LabelColumn, DefaultColumnNames.Label);
Expand Down
61 changes: 41 additions & 20 deletions src/Microsoft.ML.Data/Commands/DataCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
using System.Linq;
using Microsoft.ML.Runtime.Command;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data.IO;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;

Expand All @@ -20,8 +22,8 @@ public static class DataCommand
{
public abstract class ArgumentsBase
{
[Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "The data loader", ShortName = "loader", SortOrder = 1, NullName = "<Auto>")]
public SubComponent<IDataLoader, SignatureDataLoader> Loader;
[Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "The data loader", ShortName = "loader", SortOrder = 1, NullName = "<Auto>", SignatureType = typeof(SignatureDataLoader))]
public IComponentFactory<IMultiStreamSource, IDataLoader> Loader;

[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The data file", ShortName = "data", SortOrder = 0)]
public string DataFile;
Expand Down Expand Up @@ -51,8 +53,8 @@ public abstract class ArgumentsBase
HelpText = "Desired degree of parallelism in the data pipeline", ShortName = "n")]
public int? Parallel;

[Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Transform", ShortName = "xf")]
public KeyValuePair<string, SubComponent<IDataTransform, SignatureDataTransform>>[] Transform;
[Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Transform", ShortName = "xf", SignatureType = typeof(SignatureDataTransform))]
public KeyValuePair<string, IComponentFactory<IDataView, IDataTransform>>[] Transform;
}

public abstract class ImplBase<TArgs> : ICommand
Expand Down Expand Up @@ -125,6 +127,17 @@ protected void SendTelemetryComponent(IPipe<TelemetryMessage> pipe, SubComponent
pipe.Send(TelemetryMessage.CreateTrainer(sub.Kind, sub.SubComponentSettings));
}

protected void SendTelemetryComponent(IPipe<TelemetryMessage> pipe, IComponentFactory factory)
{
Host.AssertValue(pipe);
Host.AssertValueOrNull(factory);

if (factory is ICommandLineComponentFactory commandLineFactory)
pipe.Send(TelemetryMessage.CreateTrainer(commandLineFactory.Name, commandLineFactory.GetSettingsString()));
else
pipe.Send(TelemetryMessage.CreateTrainer("Unknown", "Non-ICommandLineComponentFactory object"));
}

protected virtual void SendTelemetryCore(IPipe<TelemetryMessage> pipe)
{
Contracts.AssertValue(pipe);
Expand Down Expand Up @@ -212,9 +225,9 @@ protected void SaveLoader(IDataLoader loader, string path)
LoaderUtils.SaveLoader(loader, file);
}

protected IDataLoader CreateAndSaveLoader(string defaultLoader = "TextLoader")
protected IDataLoader CreateAndSaveLoader(Func<IHostEnvironment, IMultiStreamSource, IDataLoader> defaultLoaderFactory = null)
{
var loader = CreateLoader(defaultLoader);
var loader = CreateLoader(defaultLoaderFactory);
if (!string.IsNullOrWhiteSpace(Args.OutputModelFile))
{
using (var file = Host.CreateOutputFile(Args.OutputModelFile))
Expand Down Expand Up @@ -268,12 +281,12 @@ protected void LoadModelObjects(
}

// Next create the loader.
var sub = Args.Loader;
var loaderFactory = Args.Loader;
IDataLoader trainPipe = null;
if (sub.IsGood())
if (loaderFactory != null)
{
// The loader is overridden from the command line.
pipe = sub.CreateInstance(Host, new MultiFileSource(Args.DataFile));
pipe = loaderFactory.CreateComponent(Host, new MultiFileSource(Args.DataFile));
if (Args.LoadTransforms == true)
{
Host.CheckUserArg(!string.IsNullOrWhiteSpace(Args.InputModelFile), nameof(Args.InputModelFile));
Expand Down Expand Up @@ -316,9 +329,9 @@ protected void LoadModelObjects(
}
}

protected IDataLoader CreateLoader(string defaultLoader = "TextLoader")
protected IDataLoader CreateLoader(Func<IHostEnvironment, IMultiStreamSource, IDataLoader> defaultLoaderFactory = null)
{
var loader = CreateRawLoader(defaultLoader);
var loader = CreateRawLoader(defaultLoaderFactory);
loader = CreateTransformChain(loader);
return loader;
}
Expand All @@ -328,13 +341,15 @@ private IDataLoader CreateTransformChain(IDataLoader loader)
return CompositeDataLoader.Create(Host, loader, Args.Transform);
}

protected IDataLoader CreateRawLoader(string defaultLoader = "TextLoader", string dataFile = null)
protected IDataLoader CreateRawLoader(
Func<IHostEnvironment, IMultiStreamSource, IDataLoader> defaultLoaderFactory = null,
string dataFile = null)
{
if (string.IsNullOrWhiteSpace(dataFile))
dataFile = Args.DataFile;

IDataLoader loader;
if (!string.IsNullOrWhiteSpace(Args.InputModelFile) && !Args.Loader.IsGood())
if (!string.IsNullOrWhiteSpace(Args.InputModelFile) && Args.Loader == null)
{
// Load the loader from the data model.
using (var file = Host.OpenInputFile(Args.InputModelFile))
Expand All @@ -345,21 +360,27 @@ protected IDataLoader CreateRawLoader(string defaultLoader = "TextLoader", strin
else
{
// Either there is no input model file, or there is, but the loader is overridden.
var sub = Args.Loader;
if (!sub.IsGood())
IMultiStreamSource fileSource = new MultiFileSource(dataFile);
var loaderFactory = Args.Loader;
if (loaderFactory == null)
{
var ext = Path.GetExtension(dataFile);
var isText =
string.Equals(ext, ".txt", StringComparison.OrdinalIgnoreCase) ||
string.Equals(ext, ".tlc", StringComparison.OrdinalIgnoreCase);
var isBinary = string.Equals(ext, ".idv", StringComparison.OrdinalIgnoreCase);
var isTranspose = string.Equals(ext, ".tdv", StringComparison.OrdinalIgnoreCase);
sub =
new SubComponent<IDataLoader, SignatureDataLoader>(
isText ? "TextLoader" : isBinary ? "BinaryLoader" : isTranspose ? "TransposeLoader" : defaultLoader);
}

loader = sub.CreateInstance(Host, new MultiFileSource(dataFile));
return isText ? new TextLoader(Host, new TextLoader.Arguments(), fileSource) :
isBinary ? new BinaryLoader(Host, new BinaryLoader.Arguments(), fileSource) :
isTranspose ? new TransposeLoader(Host, new TransposeLoader.Arguments(), fileSource) :
defaultLoaderFactory != null ? defaultLoaderFactory(Host, fileSource) :
new TextLoader(Host, new TextLoader.Arguments(), fileSource);
}
else
{
loader = loaderFactory.CreateComponent(Host, fileSource);
}

if (Args.LoadTransforms == true)
{
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Data/Commands/EvaluateCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ private void RunCore(IChannel ch)
Host.AssertValue(ch);

ch.Trace("Creating loader");
IDataView view = CreateAndSaveLoader(IO.BinaryLoader.LoadName);
IDataView view = CreateAndSaveLoader(
(env, source) => new IO.BinaryLoader(env, new IO.BinaryLoader.Arguments(), source));

ch.Trace("Binding columns");
ISchema schema = view.Schema;
Expand Down
5 changes: 3 additions & 2 deletions src/Microsoft.ML.Data/Commands/ScoreCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.ML.Runtime.Command;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;

Expand Down Expand Up @@ -62,8 +63,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to include hidden columns", ShortName = "keep")]
public bool KeepHidden;

[Argument(ArgumentType.Multiple, HelpText = "Post processing transform", ShortName = "pxf")]
public KeyValuePair<string, SubComponent<IDataTransform, SignatureDataTransform>>[] PostTransform;
[Argument(ArgumentType.Multiple, HelpText = "Post processing transform", ShortName = "pxf", SignatureType = typeof(SignatureDataTransform))]
public KeyValuePair<string, IComponentFactory<IDataView, IDataTransform>>[] PostTransform;

[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to output all columns or just scores", ShortName = "all")]
public bool? OutputAllColumns;
Expand Down
Loading