Skip to content

Creation of components through MLContext and cleanup (OneHotHash, Hash, CopyCol, KeyToVector) #2364

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public static CommonOutputs.TransformOutput SelectColumns(IHostEnvironment env,
}

[TlcModule.EntryPoint(Name = "Transforms.ColumnCopier", Desc = "Duplicates columns from the dataset", UserName = ColumnCopyingTransformer.UserName, ShortName = ColumnCopyingTransformer.ShortName)]
public static CommonOutputs.TransformOutput CopyColumns(IHostEnvironment env, ColumnCopyingTransformer.Arguments input)
public static CommonOutputs.TransformOutput CopyColumns(IHostEnvironment env, ColumnCopyingTransformer.Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("CopyColumns");
Expand Down
4 changes: 4 additions & 0 deletions src/Microsoft.ML.Data/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TensorFlow" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Transforms" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.DnnImageFeaturizer.AlexNet" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.DnnImageFeaturizer.ResNet101" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.DnnImageFeaturizer.ResNet18" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.DnnImageFeaturizer.ResNet50" + PublicKey.Value)]

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StaticPipe" + PublicKey.Value)]

Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/TrainCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,11 @@ private void EnsureStratificationColumn(ref IDataView data, ref string stratific
// Generate a new column with the hashed stratification column.
while (data.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc);
HashingTransformer.ColumnInfo columnInfo;
HashingEstimator.ColumnInfo columnInfo;
if (seed.HasValue)
columnInfo = new HashingTransformer.ColumnInfo(stratificationColumn, origStratCol, 30, seed.Value);
columnInfo = new HashingEstimator.ColumnInfo(stratificationColumn, origStratCol, 30, seed.Value);
else
columnInfo = new HashingTransformer.ColumnInfo(stratificationColumn, origStratCol, 30);
columnInfo = new HashingEstimator.ColumnInfo(stratificationColumn, origStratCol, 30);
data = new HashingEstimator(Host, columnInfo).Fit(data).Transform(data);
}
}
Expand Down
30 changes: 21 additions & 9 deletions src/Microsoft.ML.Data/Transforms/ColumnCopying.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
using Microsoft.ML.Transforms;

[assembly: LoadableClass(ColumnCopyingTransformer.Summary, typeof(IDataTransform), typeof(ColumnCopyingTransformer),
typeof(ColumnCopyingTransformer.Arguments), typeof(SignatureDataTransform),
typeof(ColumnCopyingTransformer.Options), typeof(SignatureDataTransform),
ColumnCopyingTransformer.UserName, "CopyColumns", "CopyColumnsTransform", ColumnCopyingTransformer.ShortName,
DocName = "transform/CopyColumnsTransformer.md")]

Expand All @@ -33,18 +33,27 @@

namespace Microsoft.ML.Transforms
{
/// <summary>
/// <see cref="ColumnCopyingEstimator"/> copies the input column to another column named as specified in the parameters of the transformation.
/// </summary>
public sealed class ColumnCopyingEstimator : TrivialEstimator<ColumnCopyingTransformer>
{
public ColumnCopyingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName) :
[BestFriend]
internal ColumnCopyingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName) :
this(env, (outputColumnName, inputColumnName))
{
}

public ColumnCopyingEstimator(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns)
[BestFriend]
internal ColumnCopyingEstimator(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ColumnCopyingEstimator)), new ColumnCopyingTransformer(env, columns))
{
}

/// <summary>
/// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
/// Used for schema propagation and verification in a pipeline.
/// </summary>
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
Expand All @@ -69,6 +78,9 @@ public sealed class ColumnCopyingTransformer : OneToOneTransformerBase
internal const string UserName = "Copy Columns Transform";
internal const string ShortName = "Copy";

/// <summary>
/// Names of output and input column pairs on which the transformation is applied.
/// </summary>
public IReadOnlyCollection<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly();

private static VersionInfo GetVersionInfo()
Expand All @@ -82,12 +94,12 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(ColumnCopyingTransformer).Assembly.FullName);
}

public ColumnCopyingTransformer(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns)
internal ColumnCopyingTransformer(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ColumnCopyingTransformer)), columns)
{
}

public sealed class Column : OneToOneColumn
internal sealed class Column : OneToOneColumn
{
internal static Column Parse(string str)
{
Expand All @@ -106,20 +118,20 @@ internal bool TryUnparse(StringBuilder sb)
}
}

public sealed class Arguments : TransformInputBase
internal sealed class Options : TransformInputBase
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)",
Name = "Column", ShortName = "col", SortOrder = 1)]
public Column[] Columns;
}

// Factory method corresponding to SignatureDataTransform.
internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(args, nameof(args));
env.CheckValue(options, nameof(options));

var transformer = new ColumnCopyingTransformer(env, args.Columns.Select(x => (x.Name, x.Source)).ToArray());
var transformer = new ColumnCopyingTransformer(env, options.Columns.Select(x => (x.Name, x.Source)).ToArray());
return transformer.MakeDataTransform(input);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public static HashingEstimator Hash(this TransformsCatalog.ConversionTransforms
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="columns">Description of dataset columns and how to process them.</param>
public static HashingEstimator Hash(this TransformsCatalog.ConversionTransforms catalog, params HashingTransformer.ColumnInfo[] columns)
public static HashingEstimator Hash(this TransformsCatalog.ConversionTransforms catalog, params HashingEstimator.ColumnInfo[] columns)
=> new HashingEstimator(CatalogUtils.GetEnvironment(catalog), columns);

/// <summary>
Expand Down Expand Up @@ -93,7 +93,7 @@ public static KeyToValueMappingEstimator MapKeyToValue(this TransformsCatalog.Co
/// <param name="catalog">The categorical transform's catalog.</param>
/// <param name="columns">The input columns to map to vectors.</param>
public static KeyToVectorMappingEstimator MapKeyToVector(this TransformsCatalog.ConversionTransforms catalog,
params KeyToVectorMappingTransformer.ColumnInfo[] columns)
params KeyToVectorMappingEstimator.ColumnInfo[] columns)
=> new KeyToVectorMappingEstimator(CatalogUtils.GetEnvironment(catalog), columns);

/// <summary>
Expand Down
Loading