Skip to content

Commit 58ff9b5

Browse files
authored
Creation of components through MLContext and cleanup (Concat, Normalizer, NA Indicator/Replace) (#2363)
1 parent d7eb0a6 commit 58ff9b5

File tree

20 files changed

+271
-181
lines changed

20 files changed

+271
-181
lines changed

src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ namespace Microsoft.ML.EntryPoints
1414
internal static class SchemaManipulation
1515
{
1616
[TlcModule.EntryPoint(Name = "Transforms.ColumnConcatenator", Desc = ColumnConcatenatingTransformer.Summary, UserName = ColumnConcatenatingTransformer.UserName, ShortName = ColumnConcatenatingTransformer.LoadName)]
17-
public static CommonOutputs.TransformOutput ConcatColumns(IHostEnvironment env, ColumnConcatenatingTransformer.Arguments input)
17+
public static CommonOutputs.TransformOutput ConcatColumns(IHostEnvironment env, ColumnConcatenatingTransformer.Options input)
1818
{
1919
Contracts.CheckValue(env, nameof(env));
2020
var host = env.Register("ConcatColumns");

src/Microsoft.ML.Data/Transforms/ColumnBindingsBase.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ protected ColumnBindingsBase(Schema input, bool user, params string[] names)
313313
// standard column name.
314314
const string standardColumnArgName = "Columns";
315315
Contracts.Assert(nameof(ValueToKeyMappingTransformer.Options.Columns) == standardColumnArgName);
316-
Contracts.Assert(nameof(ColumnConcatenatingTransformer.Arguments.Columns) == standardColumnArgName);
316+
Contracts.Assert(nameof(ColumnConcatenatingTransformer.Options.Columns) == standardColumnArgName);
317317

318318
for (int iinfo = 0; iinfo < names.Length; iinfo++)
319319
{

src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs

+14-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111

1212
namespace Microsoft.ML.Transforms
1313
{
14-
public sealed class ColumnConcatenatingEstimator : IEstimator<ITransformer>
14+
/// <summary>
15+
/// Concatenates columns in an <see cref="IDataView"/> into one single column. Estimator for the <see cref="ColumnConcatenatingTransformer"/>.
16+
/// </summary>
17+
public sealed class ColumnConcatenatingEstimator : IEstimator<ColumnConcatenatingTransformer>
1518
{
1619
private readonly IHost _host;
1720
private readonly string _name;
@@ -22,8 +25,8 @@ public sealed class ColumnConcatenatingEstimator : IEstimator<ITransformer>
2225
/// </summary>
2326
/// <param name="env">The local instance of <see cref="IHostEnvironment"/>.</param>
2427
/// <param name="outputColumnName">The name of the resulting column.</param>
25-
/// <param name="inputColumnNames">The columns to concatenate together.</param>
26-
public ColumnConcatenatingEstimator(IHostEnvironment env, string outputColumnName, params string[] inputColumnNames)
28+
/// <param name="inputColumnNames">The columns to concatenate into one single column.</param>
29+
internal ColumnConcatenatingEstimator(IHostEnvironment env, string outputColumnName, params string[] inputColumnNames)
2730
{
2831
Contracts.CheckValue(env, nameof(env));
2932
_host = env.Register("ColumnConcatenatingEstimator ");
@@ -37,7 +40,10 @@ public ColumnConcatenatingEstimator(IHostEnvironment env, string outputColumnNam
3740
_source = inputColumnNames;
3841
}
3942

40-
public ITransformer Fit(IDataView input)
43+
/// <summary>
44+
/// Trains and returns a <see cref="ColumnConcatenatingTransformer"/>.
45+
/// </summary>
46+
public ColumnConcatenatingTransformer Fit(IDataView input)
4147
{
4248
_host.CheckValue(input, nameof(input));
4349
return new ColumnConcatenatingTransformer(_host, _name, _source);
@@ -109,6 +115,10 @@ private SchemaShape.Column CheckInputsAndMakeColumn(
109115
return new SchemaShape.Column(name, vecKind, itemType, false, new SchemaShape(meta));
110116
}
111117

118+
/// <summary>
119+
/// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
120+
/// Used for schema propagation and verification in a pipeline.
121+
/// </summary>
112122
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
113123
{
114124
_host.CheckValue(inputSchema, nameof(inputSchema));

src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs

+35-23
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515
using Microsoft.ML.Model;
1616
using Microsoft.ML.Model.Onnx;
1717
using Microsoft.ML.Model.Pfa;
18+
using Microsoft.ML.Transforms;
1819
using Newtonsoft.Json.Linq;
1920

20-
[assembly: LoadableClass(ColumnConcatenatingTransformer.Summary, typeof(IDataTransform), typeof(ColumnConcatenatingTransformer), typeof(ColumnConcatenatingTransformer.TaggedArguments), typeof(SignatureDataTransform),
21+
[assembly: LoadableClass(ColumnConcatenatingTransformer.Summary, typeof(IDataTransform), typeof(ColumnConcatenatingTransformer), typeof(ColumnConcatenatingTransformer.TaggedOptions), typeof(SignatureDataTransform),
2122
ColumnConcatenatingTransformer.UserName, ColumnConcatenatingTransformer.LoadName, "ConcatTransform", DocName = "transform/ConcatTransform.md")]
2223

2324
[assembly: LoadableClass(ColumnConcatenatingTransformer.Summary, typeof(IDataTransform), typeof(ColumnConcatenatingTransformer), null, typeof(SignatureLoadDataTransform),
@@ -33,6 +34,10 @@ namespace Microsoft.ML.Data
3334
{
3435
using PfaType = PfaUtils.Type;
3536

37+
/// <summary>
38+
/// Concatenates columns in an <see cref="IDataView"/> into one single column. Please see <see cref="ColumnConcatenatingEstimator"/> for
39+
/// constructing <see cref="ColumnConcatenatingTransformer"/>.
40+
/// </summary>
3641
public sealed class ColumnConcatenatingTransformer : RowToRowTransformerBase
3742
{
3843
internal const string Summary = "Concatenates one or more columns of the same item type.";
@@ -42,7 +47,7 @@ public sealed class ColumnConcatenatingTransformer : RowToRowTransformerBase
4247
internal const string LoaderSignature = "ConcatTransform";
4348
internal const string LoaderSignatureOld = "ConcatFunction";
4449

45-
public sealed class Column : ManyToOneColumn
50+
internal sealed class Column : ManyToOneColumn
4651
{
4752
internal static Column Parse(string str)
4853
{
@@ -60,7 +65,8 @@ internal bool TryUnparse(StringBuilder sb)
6065
}
6166
}
6267

63-
public sealed class TaggedColumn
68+
[BestFriend]
69+
internal sealed class TaggedColumn
6470
{
6571
[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the new column", ShortName = "name")]
6672
public string Name;
@@ -99,13 +105,13 @@ internal bool TryUnparse(StringBuilder sb)
99105
}
100106
}
101107

102-
public sealed class Arguments : TransformInputBase
108+
internal sealed class Options : TransformInputBase
103109
{
104-
public Arguments()
110+
public Options()
105111
{
106112
}
107113

108-
public Arguments(string name, params string[] source)
114+
public Options(string name, params string[] source)
109115
{
110116
Columns = new[] { new Column()
111117
{
@@ -119,14 +125,16 @@ public Arguments(string name, params string[] source)
119125
public Column[] Columns;
120126
}
121127

122-
public sealed class TaggedArguments
128+
[BestFriend]
129+
internal sealed class TaggedOptions
123130
{
124131
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:srcs)",
125132
Name = "Column", ShortName = "col", SortOrder = 1)]
126133
public TaggedColumn[] Columns;
127134
}
128135

129-
public sealed class ColumnInfo
136+
[BestFriend]
137+
internal sealed class ColumnInfo
130138
{
131139
public readonly string Name;
132140
private readonly (string name, string alias)[] _sources;
@@ -212,22 +220,26 @@ internal ColumnInfo(ModelLoadContext ctx)
212220

213221
private readonly ColumnInfo[] _columns;
214222

215-
public IReadOnlyCollection<ColumnInfo> Columns => _columns.AsReadOnly();
223+
/// <summary>
224+
/// The names of the output and input column pairs for the transformation.
225+
/// </summary>
226+
public IReadOnlyCollection<(string outputColumnName, string[] inputColumnNames)> Columns
227+
=> _columns.Select(col => (outputColumnName: col.Name, inputColumnNames: col.Sources.Select(source => source.name).ToArray())).ToArray().AsReadOnly();
216228

217229
/// <summary>
218230
/// Concatename columns in <paramref name="inputColumnNames"/> into one column <paramref name="outputColumnName"/>.
219231
/// Original columns are also preserved.
220232
/// The column types must match, and the output column type is always a vector.
221233
/// </summary>
222-
public ColumnConcatenatingTransformer(IHostEnvironment env, string outputColumnName, params string[] inputColumnNames)
234+
internal ColumnConcatenatingTransformer(IHostEnvironment env, string outputColumnName, params string[] inputColumnNames)
223235
: this(env, new ColumnInfo(outputColumnName, inputColumnNames))
224236
{
225237
}
226238

227239
/// <summary>
228240
/// Concatenates multiple groups of columns, each group is denoted by one of <paramref name="columns"/>.
229241
/// </summary>
230-
public ColumnConcatenatingTransformer(IHostEnvironment env, params ColumnInfo[] columns) :
242+
internal ColumnConcatenatingTransformer(IHostEnvironment env, params ColumnInfo[] columns) :
231243
base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ColumnConcatenatingTransformer)))
232244
{
233245
Contracts.CheckValue(columns, nameof(columns));
@@ -357,17 +369,17 @@ private ColumnInfo[] LoadLegacy(ModelLoadContext ctx)
357369
///<summary>
358370
/// Factory method for SignatureDataTransform.
359371
/// </summary>
360-
internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
372+
internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
361373
{
362374
Contracts.CheckValue(env, nameof(env));
363-
env.CheckValue(args, nameof(args));
375+
env.CheckValue(options, nameof(options));
364376
env.CheckValue(input, nameof(input));
365-
env.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns));
377+
env.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns));
366378

367-
for (int i = 0; i < args.Columns.Length; i++)
368-
env.CheckUserArg(Utils.Size(args.Columns[i].Source) > 0, nameof(args.Columns));
379+
for (int i = 0; i < options.Columns.Length; i++)
380+
env.CheckUserArg(Utils.Size(options.Columns[i].Source) > 0, nameof(options.Columns));
369381

370-
var cols = args.Columns
382+
var cols = options.Columns
371383
.Select(c => new ColumnInfo(c.Name, c.Source))
372384
.ToArray();
373385
var transformer = new ColumnConcatenatingTransformer(env, cols);
@@ -377,17 +389,17 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat
377389
/// Factory method corresponding to SignatureDataTransform.
378390
/// </summary>
379391
[BestFriend]
380-
internal static IDataTransform Create(IHostEnvironment env, TaggedArguments args, IDataView input)
392+
internal static IDataTransform Create(IHostEnvironment env, TaggedOptions options, IDataView input)
381393
{
382394
Contracts.CheckValue(env, nameof(env));
383-
env.CheckValue(args, nameof(args));
395+
env.CheckValue(options, nameof(options));
384396
env.CheckValue(input, nameof(input));
385-
env.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns));
397+
env.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns));
386398

387-
for (int i = 0; i < args.Columns.Length; i++)
388-
env.CheckUserArg(Utils.Size(args.Columns[i].Source) > 0, nameof(args.Columns));
399+
for (int i = 0; i < options.Columns.Length; i++)
400+
env.CheckUserArg(Utils.Size(options.Columns[i].Source) > 0, nameof(options.Columns));
389401

390-
var cols = args.Columns
402+
var cols = options.Columns
391403
.Select(c => new ColumnInfo(c.Name, c.Source.Select(kvp => (kvp.Value, kvp.Key != "" ? kvp.Key : null))))
392404
.ToArray();
393405
var transformer = new ColumnConcatenatingTransformer(env, cols);

src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ public static ColumnCopyingEstimator CopyColumns(this TransformsCatalog catalog,
4545
=> new ColumnCopyingEstimator(CatalogUtils.GetEnvironment(catalog), columns);
4646

4747
/// <summary>
48-
/// Concatenates two columns together.
48+
/// Concatenates columns together.
4949
/// </summary>
5050
/// <param name="catalog">The transform's catalog.</param>
5151
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnNames"/>.</param>

src/Microsoft.ML.Data/Transforms/Normalizer.cs

+13-7
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
using Microsoft.ML.Model.Pfa;
1717
using Microsoft.ML.Transforms.Normalizers;
1818
using Newtonsoft.Json.Linq;
19-
using static Microsoft.ML.Transforms.Normalizers.NormalizeTransform;
2019

2120
[assembly: LoadableClass(typeof(NormalizingTransformer), null, typeof(SignatureLoadModel),
2221
"", NormalizingTransformer.LoaderSignature)]
@@ -206,7 +205,7 @@ internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, C
206205
/// <param name="inputColumnName">Name of the column to transform.
207206
/// If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
208207
/// <param name="mode">The <see cref="NormalizerMode"/> indicating how to the old values are mapped to the new values.</param>
209-
public NormalizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, NormalizerMode mode = NormalizerMode.MinMax)
208+
internal NormalizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, NormalizerMode mode = NormalizerMode.MinMax)
210209
: this(env, mode, (outputColumnName, inputColumnName ?? outputColumnName))
211210
{
212211
}
@@ -217,7 +216,7 @@ public NormalizingEstimator(IHostEnvironment env, string outputColumnName, strin
217216
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
218217
/// <param name="mode">The <see cref="NormalizerMode"/> indicating how to the old values are mapped to the new values.</param>
219218
/// <param name="columns">An array of (outputColumnName, inputColumnName) tuples.</param>
220-
public NormalizingEstimator(IHostEnvironment env, NormalizerMode mode, params (string outputColumnName, string inputColumnName)[] columns)
219+
internal NormalizingEstimator(IHostEnvironment env, NormalizerMode mode, params (string outputColumnName, string inputColumnName)[] columns)
221220
{
222221
Contracts.CheckValue(env, nameof(env));
223222
_host = env.Register(nameof(NormalizingEstimator));
@@ -230,7 +229,7 @@ public NormalizingEstimator(IHostEnvironment env, NormalizerMode mode, params (s
230229
/// </summary>
231230
/// <param name="env">The private instance of the <see cref="IHostEnvironment"/>.</param>
232231
/// <param name="columns">An array of <see cref="ColumnBase"/> defining the inputs to the Normalizer, and their settings.</param>
233-
public NormalizingEstimator(IHostEnvironment env, params ColumnBase[] columns)
232+
internal NormalizingEstimator(IHostEnvironment env, params ColumnBase[] columns)
234233
{
235234
Contracts.CheckValue(env, nameof(env));
236235
_host = env.Register(nameof(NormalizingEstimator));
@@ -239,12 +238,19 @@ public NormalizingEstimator(IHostEnvironment env, params ColumnBase[] columns)
239238
_columns = columns.ToArray();
240239
}
241240

241+
/// <summary>
242+
/// Trains and returns a <see cref="NormalizingTransformer"/>.
243+
/// </summary>
242244
public NormalizingTransformer Fit(IDataView input)
243245
{
244246
_host.CheckValue(input, nameof(input));
245247
return NormalizingTransformer.Train(_host, input, _columns);
246248
}
247249

250+
/// <summary>
251+
/// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
252+
/// Used for schema propagation and verification in a pipeline.
253+
/// </summary>
248254
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
249255
{
250256
_host.CheckValue(inputSchema, nameof(inputSchema));
@@ -275,7 +281,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
275281

276282
public sealed partial class NormalizingTransformer : OneToOneTransformerBase
277283
{
278-
public const string LoaderSignature = "Normalizer";
284+
internal const string LoaderSignature = "Normalizer";
279285

280286
internal const string LoaderSignatureOld = "NormalizeFunction";
281287

@@ -387,7 +393,7 @@ private NormalizingTransformer(IHostEnvironment env, ColumnInfo[] columns)
387393
ColumnFunctions = new ColumnFunctionAccessor(Columns);
388394
}
389395

390-
public static NormalizingTransformer Train(IHostEnvironment env, IDataView data, NormalizingEstimator.ColumnBase[] columns)
396+
internal static NormalizingTransformer Train(IHostEnvironment env, IDataView data, NormalizingEstimator.ColumnBase[] columns)
391397
{
392398
Contracts.CheckValue(env, nameof(env));
393399
env.CheckValue(data, nameof(data));
@@ -510,7 +516,7 @@ private NormalizingTransformer(IHost host, ModelLoadContext ctx, IDataView input
510516
Columns = ImmutableArray.Create(cols);
511517
}
512518

513-
public static NormalizingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
519+
private static NormalizingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
514520
{
515521
Contracts.CheckValue(env, nameof(env));
516522
env.CheckValue(ctx, nameof(ctx));

src/Microsoft.ML.EntryPoints/FeatureCombiner.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ public static CommonOutputs.TransformOutput PrepareFeatures(IHostEnvironment env
7474
// as a group id. That's just one example - you get the idea.
7575
string nameFeat = DefaultColumnNames.Features;
7676
viewTrain = ColumnConcatenatingTransformer.Create(host,
77-
new ColumnConcatenatingTransformer.TaggedArguments()
77+
new ColumnConcatenatingTransformer.TaggedOptions()
7878
{
7979
Columns =
8080
new[] { new ColumnConcatenatingTransformer.TaggedColumn() { Name = nameFeat, Source = concatNames.ToArray() } }

0 commit comments

Comments
 (0)