diff --git a/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs b/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs index 896a7c6aa9..ad0e460c20 100644 --- a/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs +++ b/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs @@ -22,7 +22,8 @@ public static class ConversionsExtensionsCatalog /// </summary> /// <param name="catalog">The transform's catalog.</param> /// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param> - /// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param> /// <param name="hashBits">Number of bits to hash into. Must be between 1 and 31, inclusive.</param> + /// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param> + /// <param name="hashBits">Number of bits to hash into. Must be between 1 and 31, inclusive.</param> /// <param name="invertHash">During hashing we constuct mappings between original values and the produced hash values. /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. /// <paramref name="invertHash"/> specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. @@ -55,7 +56,7 @@ public static TypeConvertingEstimator ConvertType(this TransformsCatalog.Convers /// </summary> /// <param name="catalog">The transform's catalog.</param> /// <param name="columns">Description of dataset columns and how to process them.</param> - public static TypeConvertingEstimator ConvertType(this TransformsCatalog.ConversionTransforms catalog, params TypeConvertingTransformer.ColumnInfo[] columns) + public static TypeConvertingEstimator ConvertType(this TransformsCatalog.ConversionTransforms catalog, params TypeConvertingEstimator.ColumnInfo[] columns) => new TypeConvertingEstimator(CatalogUtils.GetEnvironment(catalog), columns); /// <summary> diff --git a/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransform.cs b/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs similarity index 94% rename from src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransform.cs rename to src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs index dd49ed484c..461a3f61e9 100644 --- a/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs @@ -76,7 +76,7 @@ namespace Microsoft.ML.Data /// </example> public sealed class FeatureContributionCalculatingTransformer : OneToOneTransformerBase { - public sealed class Arguments : TransformInputBase + internal sealed class Options : TransformInputBase { [Argument(ArgumentType.Required, HelpText = "The predictor model to apply to data", SortOrder = 1)] public PredictorModel PredictorModel; @@ -99,9 +99,9 @@ public sealed class Arguments : TransformInputBase internal const string FriendlyName = "Feature Contribution Calculation"; internal const string LoaderSignature = "FeatureContribution"; - public readonly int Top; - public readonly int Bottom; - public readonly bool Normalize; + internal readonly int Top; + internal readonly int Bottom; + internal readonly bool Normalize; private readonly IFeatureContributionMapper _predictor; @@ -128,7 +128,7 @@ private static VersionInfo GetVersionInfo() /// <param name="numNegativeContributions">The number of negative contributions to report, sorted from highest magnitude to lowest magnitude. /// Note that if there are fewer features with negative contributions than <paramref name="numNegativeContributions"/>, the rest will be returned as zeros.</param> /// <param name="normalize">Whether the feature contributions should be normalized to the [-1, 1] interval.</param> - public FeatureContributionCalculatingTransformer(IHostEnvironment env, ICalculateFeatureContribution modelParameters, + internal FeatureContributionCalculatingTransformer(IHostEnvironment env, ICalculateFeatureContribution modelParameters, string featureColumn = DefaultColumnNames.Features, int numPositiveContributions = FeatureContributionCalculatingEstimator.Defaults.NumPositiveContributions, int numNegativeContributions = FeatureContributionCalculatingEstimator.Defaults.NumNegativeContributions, @@ -281,7 +281,7 @@ public sealed class FeatureContributionCalculatingEstimator : TrivialEstimator<F private readonly string _featureColumn; private readonly ICalculateFeatureContribution _predictor; - public static class Defaults + internal static class Defaults { public const int NumPositiveContributions = 10; public const int NumNegativeContributions = 10; @@ -300,7 +300,7 @@ public static class Defaults /// <param name="numNegativeContributions">The number of negative contributions to report, sorted from highest magnitude to lowest magnitude. /// Note that if there are fewer features with negative contributions than <paramref name="numNegativeContributions"/>, the rest will be returned as zeros.</param> /// <param name="normalize">Whether the feature contributions should be normalized to the [-1, 1] interval.</param> - public FeatureContributionCalculatingEstimator(IHostEnvironment env, ICalculateFeatureContribution modelParameters, + internal FeatureContributionCalculatingEstimator(IHostEnvironment env, ICalculateFeatureContribution modelParameters, string featureColumn = DefaultColumnNames.Features, int numPositiveContributions = Defaults.NumPositiveContributions, int numNegativeContributions = Defaults.NumNegativeContributions, @@ -312,6 +312,10 @@ public FeatureContributionCalculatingEstimator(IHostEnvironment env, ICalculateF _predictor = modelParameters; } + /// <summary> + /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer. + /// Used for schema propagation and verification in a pipeline. + /// </summary> public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { // Check that the featureColumn is present. @@ -341,20 +345,20 @@ internal static class FeatureContributionEntryPoint [TlcModule.EntryPoint(Name = "Transforms.FeatureContributionCalculationTransformer", Desc = FeatureContributionCalculatingTransformer.Summary, UserName = FeatureContributionCalculatingTransformer.FriendlyName)] - public static CommonOutputs.TransformOutput FeatureContributionCalculation(IHostEnvironment env, FeatureContributionCalculatingTransformer.Arguments args) + public static CommonOutputs.TransformOutput FeatureContributionCalculation(IHostEnvironment env, FeatureContributionCalculatingTransformer.Options options) { Contracts.CheckValue(env, nameof(env)); var host = env.Register(nameof(FeatureContributionCalculatingTransformer)); - host.CheckValue(args, nameof(args)); - EntryPointUtils.CheckInputArgs(host, args); - host.CheckValue(args.PredictorModel, nameof(args.PredictorModel)); + host.CheckValue(options, nameof(options)); + EntryPointUtils.CheckInputArgs(host, options); + host.CheckValue(options.PredictorModel, nameof(options.PredictorModel)); - var predictor = args.PredictorModel.Predictor as ICalculateFeatureContribution; + var predictor = options.PredictorModel.Predictor as ICalculateFeatureContribution; if (predictor == null) throw host.ExceptUserArg(nameof(predictor), "The provided model parameters do not support feature contribution calculation."); - var outData = new FeatureContributionCalculatingTransformer(host, predictor, args.FeatureColumn, args.Top, args.Bottom, args.Normalize).Transform(args.Data); + var outData = new FeatureContributionCalculatingTransformer(host, predictor, options.FeatureColumn, options.Top, options.Bottom, options.Normalize).Transform(options.Data); - return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, outData, args.Data), OutputData = outData}; + return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, outData, options.Data), OutputData = outData}; } } } diff --git a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs similarity index 98% rename from src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs rename to src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs index 631ed517e1..9b57d6d031 100644 --- a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs @@ -16,7 +16,7 @@ using Microsoft.ML.Model; using Microsoft.ML.Transforms.FeatureSelection; -[assembly: LoadableClass(SlotsDroppingTransformer.Summary, typeof(IDataTransform), typeof(SlotsDroppingTransformer), typeof(SlotsDroppingTransformer.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(SlotsDroppingTransformer.Summary, typeof(IDataTransform), typeof(SlotsDroppingTransformer), typeof(SlotsDroppingTransformer.Options), typeof(SignatureDataTransform), SlotsDroppingTransformer.FriendlyName, SlotsDroppingTransformer.LoaderSignature, "DropSlots")] [assembly: LoadableClass(SlotsDroppingTransformer.Summary, typeof(IDataTransform), typeof(SlotsDroppingTransformer), null, typeof(SignatureLoadDataTransform), @@ -37,14 +37,15 @@ namespace Microsoft.ML.Transforms.FeatureSelection /// </summary> public sealed class SlotsDroppingTransformer : OneToOneTransformerBase { - public sealed class Arguments + internal sealed class Options { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Columns to drop the slots for", Name = "Column", ShortName = "col", SortOrder = 1)] public Column[] Columns; } - public sealed class Column : OneToOneColumn + [BestFriend] + internal sealed class Column : OneToOneColumn { [Argument(ArgumentType.Multiple, HelpText = "Source slot index range(s) of the column to drop")] public Range[] Slots; @@ -112,7 +113,7 @@ internal bool TryUnparse(StringBuilder sb) } } - public sealed class Range + internal sealed class Range { [Argument(ArgumentType.Required, HelpText = "First index in the range")] public int Min; @@ -191,7 +192,8 @@ public bool IsValid() /// <summary> /// Describes how the transformer handles one input-output column pair. /// </summary> - public sealed class ColumnInfo + [BestFriend] + internal sealed class ColumnInfo { public readonly string Name; public readonly string InputColumnName; @@ -258,7 +260,7 @@ private static VersionInfo GetVersionInfo() /// <param name="inputColumnName">Name of column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param> /// <param name="min">Specifies the lower bound of the range of slots to be dropped. The lower bound is inclusive. </param> /// <param name="max">Specifies the upper bound of the range of slots to be dropped. The upper bound is exclusive.</param> - public SlotsDroppingTransformer(IHostEnvironment env, string outputColumnName, string inputColumnName = null, int min = default, int? max = null) + internal SlotsDroppingTransformer(IHostEnvironment env, string outputColumnName, string inputColumnName = null, int min = default, int? max = null) : this(env, new ColumnInfo(outputColumnName, inputColumnName, (min, max))) { } @@ -268,7 +270,7 @@ public SlotsDroppingTransformer(IHostEnvironment env, string outputColumnName, s /// </summary> /// <param name="env">The environment to use.</param> /// <param name="columns">Specifies the ranges of slots to drop for each column pair.</param> - public SlotsDroppingTransformer(IHostEnvironment env, params ColumnInfo[] columns) + internal SlotsDroppingTransformer(IHostEnvironment env, params ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) { Host.AssertNonEmpty(ColumnPairs); @@ -308,9 +310,9 @@ private static SlotsDroppingTransformer Create(IHostEnvironment env, ModelLoadCo } // Factory method for SignatureDataTransform. - private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { - var columns = args.Columns.Select(column => new ColumnInfo(column)).ToArray(); + var columns = options.Columns.Select(column => new ColumnInfo(column)).ToArray(); return new SlotsDroppingTransformer(env, columns).MakeDataTransform(input); } diff --git a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs index 099518be67..5bdbdb0a29 100644 --- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs +++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs @@ -17,7 +17,7 @@ using Microsoft.ML.Model.Onnx; using Microsoft.ML.Transforms.Conversions; -[assembly: LoadableClass(TypeConvertingTransformer.Summary, typeof(IDataTransform), typeof(TypeConvertingTransformer), typeof(TypeConvertingTransformer.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(TypeConvertingTransformer.Summary, typeof(IDataTransform), typeof(TypeConvertingTransformer), typeof(TypeConvertingTransformer.Options), typeof(SignatureDataTransform), TypeConvertingTransformer.UserName, TypeConvertingTransformer.ShortName, "ConvertTransform", DocName = "transform/ConvertTransform.md")] [assembly: LoadableClass(TypeConvertingTransformer.Summary, typeof(IDataTransform), typeof(TypeConvertingTransformer), null, typeof(SignatureLoadDataTransform), @@ -36,7 +36,7 @@ namespace Microsoft.ML.Transforms.Conversions internal static class TypeConversion { [TlcModule.EntryPoint(Name = "Transforms.ColumnTypeConverter", Desc = TypeConvertingTransformer.Summary, UserName = TypeConvertingTransformer.UserName, ShortName = TypeConvertingTransformer.ShortName)] - public static CommonOutputs.TransformOutput Convert(IHostEnvironment env, TypeConvertingTransformer.Arguments input) + public static CommonOutputs.TransformOutput Convert(IHostEnvironment env, TypeConvertingTransformer.Options input) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(input, nameof(input)); @@ -53,11 +53,13 @@ public static CommonOutputs.TransformOutput Convert(IHostEnvironment env, TypeCo } /// <summary> - /// ConvertTransform allow to change underlying column type as long as we know how to convert types. + /// <see cref="TypeConvertingTransformer"/> converts underlying column types. + /// The source and destination column types need to be compatible. /// </summary> public sealed class TypeConvertingTransformer : OneToOneTransformerBase { - public class Column : OneToOneColumn + [BestFriend] + internal class Column : OneToOneColumn { [Argument(ArgumentType.AtMostOnce, HelpText = "The result type", ShortName = "type")] public DataKind? ResultType; @@ -127,7 +129,8 @@ internal bool TryUnparse(StringBuilder sb) } } - public class Arguments : TransformInputBase + [BestFriend] + internal class Options : TransformInputBase { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:type:src)", Name = "Column", ShortName = "col", SortOrder = 1)] public Column[] Columns; @@ -168,37 +171,14 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "Convert"; - public IReadOnlyCollection<ColumnInfo> Columns => _columns.AsReadOnly(); - /// <summary> - /// Describes how the transformer handles one column pair. + /// A collection of <see cref="TypeConvertingEstimator.ColumnInfo"/> describing the settings of the transformation. /// </summary> - public sealed class ColumnInfo - { - public readonly string Name; - public readonly string InputColumnName; - public readonly DataKind OutputKind; - public readonly KeyCount OutputKeyCount; + public IReadOnlyCollection<TypeConvertingEstimator.ColumnInfo> Columns => _columns.AsReadOnly(); - /// <summary> - /// Describes how the transformer handles one column pair. - /// </summary> - /// <param name="name">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param> - /// <param name="outputKind">The expected kind of the converted column.</param> - /// <param name="inputColumnName">Name of column to transform. If set to <see langword="null"/>, the value of the <paramref name="name"/> will be used as source.</param> - /// <param name="outputKeyCount">New key range, if we work with key type.</param> - public ColumnInfo(string name, DataKind outputKind, string inputColumnName, KeyCount outputKeyCount = null) - { - Name = name; - InputColumnName = inputColumnName ?? name; - OutputKind = outputKind; - OutputKeyCount = outputKeyCount; - } - } - - private readonly ColumnInfo[] _columns; + private readonly TypeConvertingEstimator.ColumnInfo[] _columns; - private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(TypeConvertingEstimator.ColumnInfo[] columns) { Contracts.CheckNonEmpty(columns, nameof(columns)); return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); @@ -212,15 +192,15 @@ private static (string outputColumnName, string inputColumnName)[] GetColumnPair /// <param name="inputColumnName">Name of the column to be transformed. If this is null '<paramref name="outputColumnName"/>' will be used.</param> /// <param name="outputKind">The expected type of the converted column.</param> /// <param name="outputKeyCount">New key count if we work with key type.</param> - public TypeConvertingTransformer(IHostEnvironment env, string outputColumnName, DataKind outputKind, string inputColumnName = null, KeyCount outputKeyCount = null) - : this(env, new ColumnInfo(outputColumnName, outputKind, inputColumnName ?? outputColumnName, outputKeyCount)) + internal TypeConvertingTransformer(IHostEnvironment env, string outputColumnName, DataKind outputKind, string inputColumnName = null, KeyCount outputKeyCount = null) + : this(env, new TypeConvertingEstimator.ColumnInfo(outputColumnName, outputKind, inputColumnName ?? outputColumnName, outputKeyCount)) { } /// <summary> /// Create a <see cref="TypeConvertingTransformer"/> that takes multiple pairs of columns. /// </summary> - public TypeConvertingTransformer(IHostEnvironment env, params ColumnInfo[] columns) + internal TypeConvertingTransformer(IHostEnvironment env, params TypeConvertingEstimator.ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TypeConvertingTransformer)), GetColumnPairs(columns)) { _columns = columns.ToArray(); @@ -281,7 +261,7 @@ private TypeConvertingTransformer(IHost host, ModelLoadContext ctx) // if there is a keyCount // ulong: keyCount (0 for unspecified) - _columns = new ColumnInfo[columnsLength]; + _columns = new TypeConvertingEstimator.ColumnInfo[columnsLength]; for (int i = 0; i < columnsLength; i++) { byte b = ctx.Reader.ReadByte(); @@ -310,23 +290,23 @@ private TypeConvertingTransformer(IHost host, ModelLoadContext ctx) keyCount = new KeyCount(count); } - _columns[i] = new ColumnInfo(ColumnPairs[i].outputColumnName, kind, ColumnPairs[i].inputColumnName, keyCount); + _columns[i] = new TypeConvertingEstimator.ColumnInfo(ColumnPairs[i].outputColumnName, kind, ColumnPairs[i].inputColumnName, keyCount); } } // Factory method for SignatureDataTransform. - internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); + env.CheckValue(options, nameof(options)); env.CheckValue(input, nameof(input)); - env.CheckValue(args.Columns, nameof(args.Columns)); - var cols = new ColumnInfo[args.Columns.Length]; + env.CheckValue(options.Columns, nameof(options.Columns)); + var cols = new TypeConvertingEstimator.ColumnInfo[options.Columns.Length]; for (int i = 0; i < cols.Length; i++) { - var item = args.Columns[i]; - var tempResultType = item.ResultType ?? args.ResultType; + var item = options.Columns[i]; + var tempResultType = item.ResultType ?? options.ResultType; KeyCount keyCount = null; // If KeyCount or Range are defined on this column, set keyCount to the appropriate value. if (item.KeyCount != null) @@ -337,10 +317,10 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat // defined in the Arguments object only in case the ResultType is not defined on the column. else if (item.ResultType == null) { - if (args.KeyCount != null) - keyCount = args.KeyCount; - else if (args.Range != null) - keyCount = KeyCount.Parse(args.Range); + if (options.KeyCount != null) + keyCount = options.KeyCount; + else if (options.Range != null) + keyCount = KeyCount.Parse(options.Range); } DataKind kind; @@ -358,7 +338,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat { kind = tempResultType.Value; } - cols[i] = new ColumnInfo(item.Name, kind, item.Source ?? item.Name, keyCount); + cols[i] = new TypeConvertingEstimator.ColumnInfo(item.Name, kind, item.Source ?? item.Name, keyCount); }; return new TypeConvertingTransformer(env, cols).MakeDataTransform(input); } @@ -534,7 +514,8 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, } /// <summary> - /// Convert estimator allow you take column and change it type as long as we know how to do conversion between types. + /// <see cref="TypeConvertingEstimator"/> converts underlying column types. + /// The source and destination column types need to be compatible. /// </summary> public sealed class TypeConvertingEstimator : TrivialEstimator<TypeConvertingTransformer> { @@ -543,6 +524,60 @@ internal sealed class Defaults public const DataKind DefaultOutputKind = DataKind.R4; } + /// <summary> + /// Describes how the transformer handles one column pair. + /// </summary> + public sealed class ColumnInfo + { + /// <summary> + /// Name of the column resulting from the transformation of <see cref="InputColumnName"/>. + /// </summary> + public readonly string Name; + /// <summary> + /// Name of column to transform. If set to <see langword="null"/>, the value of the <see cref="Name"/> will be used as source. + /// </summary> + public readonly string InputColumnName; + /// <summary> + /// The expected kind of the converted column. + /// </summary> + public readonly DataKind OutputKind; + /// <summary> + /// New key count, if we work with key type. + /// </summary> + public readonly KeyCount OutputKeyCount; + + /// <summary> + /// Describes how the transformer handles one column pair. + /// </summary> + /// <param name="name">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param> + /// <param name="outputKind">The expected kind of the converted column.</param> + /// <param name="inputColumnName">Name of column to transform. If set to <see langword="null"/>, the value of the <paramref name="name"/> will be used as source.</param> + /// <param name="outputKeyCount">New key count, if we work with key type.</param> + public ColumnInfo(string name, DataKind outputKind, string inputColumnName, KeyCount outputKeyCount = null) + { + Name = name; + InputColumnName = inputColumnName ?? name; + OutputKind = outputKind; + OutputKeyCount = outputKeyCount; + } + + /// <summary> + /// Describes how the transformer handles one column pair. + /// </summary> + /// <param name="name">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param> + /// <param name="type">The expected kind of the converted column.</param> + /// <param name="inputColumnName">Name of column to transform. If set to <see langword="null"/>, the value of the <paramref name="name"/> will be used as source.</param> + /// <param name="outputKeyCount">New key count, if we work with key type.</param> + public ColumnInfo(string name, Type type, string inputColumnName, KeyCount outputKeyCount = null) + { + Name = name; + InputColumnName = inputColumnName ?? name; + if (!type.TryGetDataKind(out OutputKind)) + throw Contracts.ExceptUserArg(nameof(type), $"Unsupported type {type}."); + OutputKeyCount = outputKeyCount; + } + } + /// <summary> /// Convinence constructor for simple one column case. /// </summary> @@ -550,21 +585,25 @@ internal sealed class Defaults /// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param> /// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param> /// <param name="outputKind">The expected type of the converted column.</param> - public TypeConvertingEstimator(IHostEnvironment env, + internal TypeConvertingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, DataKind outputKind = Defaults.DefaultOutputKind) - : this(env, new TypeConvertingTransformer.ColumnInfo(outputColumnName, outputKind, inputColumnName ?? outputColumnName)) + : this(env, new ColumnInfo(outputColumnName, outputKind, inputColumnName ?? outputColumnName)) { } /// <summary> /// Create a <see cref="TypeConvertingEstimator"/> that takes multiple pairs of columns. /// </summary> - public TypeConvertingEstimator(IHostEnvironment env, params TypeConvertingTransformer.ColumnInfo[] columns) : + internal TypeConvertingEstimator(IHostEnvironment env, params ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TypeConvertingEstimator)), new TypeConvertingTransformer(env, columns)) { } + /// <summary> + /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer. + /// Used for schema propagation and verification in a pipeline. + /// </summary> public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); diff --git a/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs b/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs index fd4dba2aba..0ab5651a89 100644 --- a/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs +++ b/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs @@ -58,7 +58,7 @@ public static CommonOutputs.TransformOutput PrepareFeatures(IHostEnvironment env throw ch.Except("No feature columns specified"); var featNames = new HashSet<string>(); var concatNames = new List<KeyValuePair<string, string>>(); - List<TypeConvertingTransformer.ColumnInfo> cvt; + List<TypeConvertingEstimator.ColumnInfo> cvt; int errCount; var ktv = ConvertFeatures(feats.ToArray(), featNames, concatNames, ch, out cvt, out errCount); Contracts.Assert(featNames.Count > 0); @@ -139,7 +139,7 @@ private static string GetTerms(IDataView data, string colName) return sb.ToString(); } - private static IDataView ApplyConvert(List<TypeConvertingTransformer.ColumnInfo> cvt, IDataView viewTrain, IHostEnvironment env) + private static IDataView ApplyConvert(List<TypeConvertingEstimator.ColumnInfo> cvt, IDataView viewTrain, IHostEnvironment env) { Contracts.AssertValueOrNull(cvt); Contracts.AssertValue(viewTrain); @@ -150,7 +150,7 @@ private static IDataView ApplyConvert(List<TypeConvertingTransformer.ColumnInfo> } private static List<KeyToVectorMappingEstimator.ColumnInfo> ConvertFeatures(IEnumerable<Schema.Column> feats, HashSet<string> featNames, List<KeyValuePair<string, string>> concatNames, IChannel ch, - out List<TypeConvertingTransformer.ColumnInfo> cvt, out int errCount) + out List<TypeConvertingEstimator.ColumnInfo> cvt, out int errCount) { Contracts.AssertValue(feats); Contracts.AssertValue(featNames); @@ -185,7 +185,7 @@ private static IDataView ApplyConvert(List<TypeConvertingTransformer.ColumnInfo> // This happens when the training is done on an XDF and the scoring is done on a data frame. var colName = GetUniqueName(); concatNames.Add(new KeyValuePair<string, string>(col.Name, colName)); - Utils.Add(ref cvt, new TypeConvertingTransformer.ColumnInfo(colName, DataKind.R4, col.Name)); + Utils.Add(ref cvt, new TypeConvertingEstimator.ColumnInfo(colName, DataKind.R4, col.Name)); continue; } } @@ -300,19 +300,7 @@ public static CommonOutputs.TransformOutput PrepareRegressionLabel(IHostEnvironm return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, nop, input.Data), OutputData = nop }; } - var args = new TypeConvertingTransformer.Arguments() - { - Columns = new[] - { - new TypeConvertingTransformer.Column() - { - Name = input.LabelColumn, - Source = input.LabelColumn, - ResultType = DataKind.R4 - } - } - }; - var xf = new TypeConvertingTransformer(host, new TypeConvertingTransformer.ColumnInfo(input.LabelColumn, DataKind.R4, input.LabelColumn)).Transform(input.Data); + var xf = new TypeConvertingTransformer(host, new TypeConvertingEstimator.ColumnInfo(input.LabelColumn, DataKind.R4, input.LabelColumn)).Transform(input.Data); return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf }; } } diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index eb840cc068..fc94e6099c 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -1383,7 +1383,7 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB } // Convert the group column, if one exists. if (examples.Schema.Group?.Name is string groupName) - data = new TypeConvertingTransformer(Host, new TypeConvertingTransformer.ColumnInfo(groupName, DataKind.U8, groupName)).Transform(data); + data = new TypeConvertingTransformer(Host, new TypeConvertingEstimator.ColumnInfo(groupName, DataKind.U8, groupName)).Transform(data); // Since we've passed it through a few transforms, reconstitute the mapping on the // newly transformed data. diff --git a/src/Microsoft.ML.StaticPipe/TransformsStatic.cs b/src/Microsoft.ML.StaticPipe/TransformsStatic.cs index b1c15d3633..4a30931b13 100644 --- a/src/Microsoft.ML.StaticPipe/TransformsStatic.cs +++ b/src/Microsoft.ML.StaticPipe/TransformsStatic.cs @@ -933,11 +933,11 @@ private sealed class Rec : EstimatorReconciler public override IEstimator<ITransformer> Reconcile(IHostEnvironment env, PipelineColumn[] toOutput, IReadOnlyDictionary<PipelineColumn, string> inputNames, IReadOnlyDictionary<PipelineColumn, string> outputNames, IReadOnlyCollection<string> usedNames) { - var infos = new TypeConvertingTransformer.ColumnInfo[toOutput.Length]; + var infos = new TypeConvertingEstimator.ColumnInfo[toOutput.Length]; for (int i = 0; i < toOutput.Length; ++i) { var tcol = (IConvertCol)toOutput[i]; - infos[i] = new TypeConvertingTransformer.ColumnInfo(outputNames[toOutput[i]], tcol.Kind, inputNames[tcol.Input]); + infos[i] = new TypeConvertingEstimator.ColumnInfo(outputNames[toOutput[i]], tcol.Kind, inputNames[tcol.Input]); } return new TypeConvertingEstimator(env, infos); } diff --git a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs index f8367c1e3f..d46a8d1d08 100644 --- a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs @@ -15,7 +15,7 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Transforms.FeatureSelection; -[assembly: LoadableClass(CountFeatureSelectingEstimator.Summary, typeof(IDataTransform), typeof(CountFeatureSelectingEstimator), typeof(CountFeatureSelectingEstimator.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(CountFeatureSelectingEstimator.Summary, typeof(IDataTransform), typeof(CountFeatureSelectingEstimator), typeof(CountFeatureSelectingEstimator.Options), typeof(SignatureDataTransform), CountFeatureSelectingEstimator.UserName, "CountFeatureSelectionTransform", "CountFeatureSelection")] namespace Microsoft.ML.Transforms.FeatureSelection @@ -35,7 +35,7 @@ internal static class Defaults public const long Count = 1; } - public sealed class Arguments : TransformInputBase + internal sealed class Options : TransformInputBase { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Columns to use for feature selection", Name = "Column", ShortName = "col", SortOrder = 1)] public string[] Columns; @@ -46,10 +46,16 @@ public sealed class Arguments : TransformInputBase internal static string RegistrationName = "CountFeatureSelectionTransform"; + /// <summary> + /// Describes how the transformer handles one column pair. + /// </summary> public sealed class ColumnInfo { + /// <summary> Name of the column resulting from the transformation of <see cref="InputColumnName"/>.</summary> public readonly string Name; + /// <summary> Name of the column to transform.</summary> public readonly string InputColumnName; + /// <summary> If the count of non-default values for a slot is greater than or equal to this threshold in the training data, the slot is preserved.</summary> public readonly long MinCount; /// <summary> @@ -79,7 +85,7 @@ public ColumnInfo(string name, string inputColumnName = null, long minCount = De /// ]]> /// </format> /// </example> - public CountFeatureSelectingEstimator(IHostEnvironment env, params ColumnInfo[] columns) + internal CountFeatureSelectingEstimator(IHostEnvironment env, params ColumnInfo[] columns) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(RegistrationName); @@ -100,11 +106,15 @@ public CountFeatureSelectingEstimator(IHostEnvironment env, params ColumnInfo[] /// ]]> /// </format> /// </example> - public CountFeatureSelectingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, long minCount = Defaults.Count) + internal CountFeatureSelectingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, long minCount = Defaults.Count) : this(env, new ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName, minCount)) { } + /// <summary> + /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer. + /// Used for schema propagation and verification in a pipeline. + /// </summary> public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); @@ -126,6 +136,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } + /// <summary> + /// Trains and returns a <see cref="ITransformer"/>. + /// </summary> public ITransformer Fit(IDataView input) { _host.CheckValue(input, nameof(input)); @@ -162,16 +175,16 @@ public ITransformer Fit(IDataView input) /// <summary> /// Create method corresponding to SignatureDataTransform. /// </summary> - internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register(RegistrationName); - host.CheckValue(args, nameof(args)); + host.CheckValue(options, nameof(options)); host.CheckValue(input, nameof(input)); - host.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns)); - host.CheckUserArg(args.Count > 0, nameof(args.Count)); + host.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns)); + host.CheckUserArg(options.Count > 0, nameof(options.Count)); - var columnInfos = args.Columns.Select(inColName => new ColumnInfo(inColName, minCount: args.Count)).ToArray(); + var columnInfos = options.Columns.Select(inColName => new ColumnInfo(inColName, minCount: options.Count)).ToArray(); return new CountFeatureSelectingEstimator(env, columnInfos).Fit(input).Transform(input) as IDataTransform; } @@ -245,11 +258,11 @@ public static long[][] Train(IHostEnvironment env, IDataView input, string[] col int colSrc; var colName = columns[i]; if (!schema.TryGetColumnIndex(colName, out colSrc)) - throw env.ExceptUserArg(nameof(CountFeatureSelectingEstimator.Arguments.Columns), "Source column '{0}' not found", colName); + throw env.ExceptUserArg(nameof(CountFeatureSelectingEstimator.Options.Columns), "Source column '{0}' not found", colName); var colType = schema[colSrc].Type; if (colType is VectorType vectorType && !vectorType.IsKnownSize) - throw env.ExceptUserArg(nameof(CountFeatureSelectingEstimator.Arguments.Columns), "Variable length column '{0}' is not allowed", colName); + throw env.ExceptUserArg(nameof(CountFeatureSelectingEstimator.Options.Columns), "Variable length column '{0}' is not allowed", colName); activeCols.Add(schema[colSrc]); colSrcs[i] = colSrc; diff --git a/src/Microsoft.ML.Transforms/EntryPoints/SelectFeatures.cs b/src/Microsoft.ML.Transforms/EntryPoints/SelectFeatures.cs index 199e52d5bf..7ca6a4b4fe 100644 --- a/src/Microsoft.ML.Transforms/EntryPoints/SelectFeatures.cs +++ b/src/Microsoft.ML.Transforms/EntryPoints/SelectFeatures.cs @@ -16,7 +16,7 @@ internal static class SelectFeatures [TlcModule.EntryPoint(Name = "Transforms.FeatureSelectorByCount", Desc = CountFeatureSelectingEstimator.Summary, UserName = CountFeatureSelectingEstimator.UserName)] - public static CommonOutputs.TransformOutput CountSelect(IHostEnvironment env, CountFeatureSelectingEstimator.Arguments input) + public static CommonOutputs.TransformOutput CountSelect(IHostEnvironment env, CountFeatureSelectingEstimator.Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("CountSelect"); @@ -31,7 +31,7 @@ public static CommonOutputs.TransformOutput CountSelect(IHostEnvironment env, Co Desc = MutualInformationFeatureSelectingEstimator.Summary, UserName = MutualInformationFeatureSelectingEstimator.UserName, ShortName = MutualInformationFeatureSelectingEstimator.ShortName)] - public static CommonOutputs.TransformOutput MutualInformationSelect(IHostEnvironment env, MutualInformationFeatureSelectingEstimator.Arguments input) + public static CommonOutputs.TransformOutput MutualInformationSelect(IHostEnvironment env, MutualInformationFeatureSelectingEstimator.Options input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register("MutualInformationSelect"); diff --git a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs index 72e40c934f..a5813aa392 100644 --- a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs @@ -17,7 +17,7 @@ using Microsoft.ML.Model; using Microsoft.ML.Transforms; -[assembly: LoadableClass(MissingValueDroppingTransformer.Summary, typeof(IDataTransform), typeof(MissingValueDroppingTransformer), typeof(MissingValueDroppingTransformer.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(MissingValueDroppingTransformer.Summary, typeof(IDataTransform), typeof(MissingValueDroppingTransformer), typeof(MissingValueDroppingTransformer.Options), typeof(SignatureDataTransform), MissingValueDroppingTransformer.FriendlyName, MissingValueDroppingTransformer.ShortName, "NADropTransform")] [assembly: LoadableClass(MissingValueDroppingTransformer.Summary, typeof(IDataTransform), typeof(MissingValueDroppingTransformer), null, typeof(SignatureLoadDataTransform), @@ -34,13 +34,13 @@ namespace Microsoft.ML.Transforms /// <include file='doc.xml' path='doc/members/member[@name="NADrop"]'/> public sealed class MissingValueDroppingTransformer : OneToOneTransformerBase { - public sealed class Arguments : TransformInputBase + internal sealed class Options : TransformInputBase { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Columns to drop the NAs for", Name = "Column", ShortName = "col", SortOrder = 1)] public Column[] Columns; } - public sealed class Column : OneToOneColumn + internal sealed class Column : OneToOneColumn { internal static Column Parse(string str) { @@ -75,6 +75,9 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "DropNAs"; + /// <summary> + /// The names of the input columns of the transformation and the corresponding names for the output columns. + /// </summary> public IReadOnlyList<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly(); /// <summary> @@ -82,13 +85,13 @@ private static VersionInfo GetVersionInfo() /// </summary> /// <param name="env">The environment to use.</param> /// <param name="columns">The names of the input columns of the transformation and the corresponding names for the output columns.</param> - public MissingValueDroppingTransformer(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns) + internal MissingValueDroppingTransformer(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueDroppingTransformer)), columns) { } - internal MissingValueDroppingTransformer(IHostEnvironment env, Arguments args) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueDroppingTransformer)), GetColumnPairs(args.Columns)) + internal MissingValueDroppingTransformer(IHostEnvironment env, Options options) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueDroppingTransformer)), GetColumnPairs(options.Columns)) { } @@ -118,8 +121,8 @@ private static MissingValueDroppingTransformer Create(IHostEnvironment env, Mode } // Factory method for SignatureDataTransform. - internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) - => new MissingValueDroppingTransformer(env, args).MakeDataTransform(input); + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) + => new MissingValueDroppingTransformer(env, options).MakeDataTransform(input); // Factory method for SignatureLoadDataTransform. private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) @@ -348,7 +351,7 @@ public sealed class MissingValueDroppingEstimator : TrivialEstimator<MissingValu /// </summary> /// <param name="env">The environment to use.</param> /// <param name="columns">The names of the input columns of the transformation and the corresponding names for the output columns.</param> - public MissingValueDroppingEstimator(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns) + internal MissingValueDroppingEstimator(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueDroppingEstimator)), new MissingValueDroppingTransformer(env, columns)) { Contracts.CheckValue(env, nameof(env)); @@ -360,13 +363,14 @@ public MissingValueDroppingEstimator(IHostEnvironment env, params (string output /// <param name="env">The environment to use.</param> /// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param> /// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param> - public MissingValueDroppingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null) + internal MissingValueDroppingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null) : this(env, (outputColumnName, inputColumnName ?? outputColumnName)) { } /// <summary> - /// Returns the schema that would be produced by the transformation. + /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer. + /// Used for schema propagation and verification in a pipeline. /// </summary> public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { diff --git a/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs index 9f2e780fc0..60c938f8ec 100644 --- a/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs @@ -142,7 +142,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options args, IDataV var replaceCols = new List<MissingValueReplacingEstimator.ColumnInfo>(); var naIndicatorCols = new List<MissingValueIndicatorTransformer.Column>(); - var naConvCols = new List<TypeConvertingTransformer.ColumnInfo>(); + var naConvCols = new List<TypeConvertingEstimator.ColumnInfo>(); var concatCols = new List<ColumnConcatenatingTransformer.TaggedColumn>(); var dropCols = new List<string>(); var tmpIsMissingColNames = input.Schema.GetTempColumnNames(args.Columns.Length, "IsMissing"); @@ -185,7 +185,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options args, IDataV { throw h.Except("Cannot get a DataKind for type '{0}'", replaceItemType.RawType); } - naConvCols.Add(new TypeConvertingTransformer.ColumnInfo(tmpIsMissingColName, replaceItemTypeKind, tmpIsMissingColName)); + naConvCols.Add(new TypeConvertingEstimator.ColumnInfo(tmpIsMissingColName, replaceItemTypeKind, tmpIsMissingColName)); } // Add the NAReplaceTransform column. diff --git a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs index 95c03f3cea..5b10c08ef6 100644 --- a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs @@ -16,7 +16,7 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Transforms.FeatureSelection; -[assembly: LoadableClass(MutualInformationFeatureSelectingEstimator.Summary, typeof(IDataTransform), typeof(MutualInformationFeatureSelectingEstimator), typeof(MutualInformationFeatureSelectingEstimator.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(MutualInformationFeatureSelectingEstimator.Summary, typeof(IDataTransform), typeof(MutualInformationFeatureSelectingEstimator), typeof(MutualInformationFeatureSelectingEstimator.Options), typeof(SignatureDataTransform), MutualInformationFeatureSelectingEstimator.UserName, "MutualInformationFeatureSelection", "MutualInformationFeatureSelectionTransform", MutualInformationFeatureSelectingEstimator.ShortName)] namespace Microsoft.ML.Transforms.FeatureSelection @@ -31,14 +31,15 @@ public sealed class MutualInformationFeatureSelectingEstimator : IEstimator<ITra internal const string ShortName = "MIFeatureSelection"; internal static string RegistrationName = "MutualInformationFeatureSelectionTransform"; - public static class Defaults + [BestFriend] + internal static class Defaults { public const string LabelColumn = DefaultColumnNames.Label; public const int SlotsInOutput = 1000; public const int NumBins = 256; } - public sealed class Arguments : TransformInputBase + internal sealed class Options : TransformInputBase { [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Columns to use for feature selection", Name = "Column", ShortName = "col", SortOrder = 1)] public string[] Columns; @@ -75,7 +76,7 @@ public sealed class Arguments : TransformInputBase /// ]]> /// </format> /// </example> - public MutualInformationFeatureSelectingEstimator(IHostEnvironment env, + internal MutualInformationFeatureSelectingEstimator(IHostEnvironment env, string labelColumn = Defaults.LabelColumn, int slotsInOutput = Defaults.SlotsInOutput, int numBins = Defaults.NumBins, @@ -109,12 +110,15 @@ public MutualInformationFeatureSelectingEstimator(IHostEnvironment env, /// ]]> /// </format> /// </example> - public MutualInformationFeatureSelectingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, + internal MutualInformationFeatureSelectingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, string labelColumn = Defaults.LabelColumn, int slotsInOutput = Defaults.SlotsInOutput, int numBins = Defaults.NumBins) : this(env, labelColumn, slotsInOutput, numBins, (outputColumnName, inputColumnName ?? outputColumnName)) { } + /// <summary> + /// Trains and returns a <see cref="ITransformer"/>. + /// </summary> public ITransformer Fit(IDataView input) { _host.CheckValue(input, nameof(input)); @@ -161,6 +165,10 @@ public ITransformer Fit(IDataView input) } } + /// <summary> + /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer. + /// Used for schema propagation and verification in a pipeline. + /// </summary> public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); @@ -186,19 +194,19 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) /// <summary> /// Create method corresponding to SignatureDataTransform. /// </summary> - internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { Contracts.CheckValue(env, nameof(env)); var host = env.Register(RegistrationName); - host.CheckValue(args, nameof(args)); + host.CheckValue(options, nameof(options)); host.CheckValue(input, nameof(input)); - host.CheckUserArg(Utils.Size(args.Columns) > 0, nameof(args.Columns)); - host.CheckUserArg(args.SlotsInOutput > 0, nameof(args.SlotsInOutput)); - host.CheckNonWhiteSpace(args.LabelColumn, nameof(args.LabelColumn)); - host.Check(args.NumBins > 1, "numBins must be greater than 1."); + host.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns)); + host.CheckUserArg(options.SlotsInOutput > 0, nameof(options.SlotsInOutput)); + host.CheckNonWhiteSpace(options.LabelColumn, nameof(options.LabelColumn)); + host.Check(options.NumBins > 1, "numBins must be greater than 1."); - (string outputColumnName, string inputColumnName)[] cols = args.Columns.Select(col => (col, col)).ToArray(); - return new MutualInformationFeatureSelectingEstimator(env, args.LabelColumn, args.SlotsInOutput, args.NumBins, cols).Fit(input).Transform(input) as IDataTransform; + (string outputColumnName, string inputColumnName)[] cols = options.Columns.Select(col => (col, col)).ToArray(); + return new MutualInformationFeatureSelectingEstimator(env, options.LabelColumn, options.SlotsInOutput, options.NumBins, cols).Fit(input).Transform(input) as IDataTransform; } /// <summary> @@ -369,14 +377,14 @@ public float[][] GetScores(IDataView input, string labelColumnName, string[] col if (!schema.TryGetColumnIndex(labelColumnName, out int labelCol)) { - throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Arguments.LabelColumn), + throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Options.LabelColumn), "Label column '{0}' not found", labelColumnName); } var labelType = schema[labelCol].Type; if (!IsValidColumnType(labelType)) { - throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Arguments.LabelColumn), + throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Options.LabelColumn), "Label column '{0}' does not have compatible type", labelColumnName); } @@ -387,20 +395,20 @@ public float[][] GetScores(IDataView input, string labelColumnName, string[] col var colName = columns[i]; if (!schema.TryGetColumnIndex(colName, out int colSrc)) { - throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Arguments.Columns), + throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Options.Columns), "Source column '{0}' not found", colName); } var colType = schema[colSrc].Type; if (colType is VectorType vectorType && !vectorType.IsKnownSize) { - throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Arguments.Columns), + throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Options.Columns), "Variable length column '{0}' is not allowed", colName); } if (!IsValidColumnType(colType.GetItemType())) { - throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Arguments.Columns), + throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Options.Columns), "Column '{0}' of type '{1}' does not have compatible type.", colName, colType); } diff --git a/src/Microsoft.ML.Transforms/NAHandling.cs b/src/Microsoft.ML.Transforms/NAHandling.cs index 15f65b729e..8a569bf186 100644 --- a/src/Microsoft.ML.Transforms/NAHandling.cs +++ b/src/Microsoft.ML.Transforms/NAHandling.cs @@ -18,7 +18,7 @@ internal static class NAHandling Desc = MissingValueDroppingTransformer.Summary, UserName = MissingValueDroppingTransformer.FriendlyName, ShortName = MissingValueDroppingTransformer.ShortName)] - public static CommonOutputs.TransformOutput Drop(IHostEnvironment env, MissingValueDroppingTransformer.Arguments input) + public static CommonOutputs.TransformOutput Drop(IHostEnvironment env, MissingValueDroppingTransformer.Options input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, MissingValueDroppingTransformer.ShortName, input); var xf = MissingValueDroppingTransformer.Create(h, input, input.Data); diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 8c683c0576..4f8bf4b49a 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -79,7 +79,7 @@ Transforms.CharacterTokenizer Character-oriented tokenizer where text is conside Transforms.ColumnConcatenator Concatenates one or more columns of the same item type. Microsoft.ML.EntryPoints.SchemaManipulation ConcatColumns Microsoft.ML.Data.ColumnConcatenatingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.ColumnCopier Duplicates columns from the dataset Microsoft.ML.EntryPoints.SchemaManipulation CopyColumns Microsoft.ML.Transforms.ColumnCopyingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.ColumnSelector Selects a set of columns, dropping all others Microsoft.ML.EntryPoints.SchemaManipulation SelectColumns Microsoft.ML.Transforms.ColumnSelectingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -Transforms.ColumnTypeConverter Converts a column to a different type, using standard conversions. Microsoft.ML.Transforms.Conversions.TypeConversion Convert Microsoft.ML.Transforms.Conversions.TypeConvertingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.ColumnTypeConverter Converts a column to a different type, using standard conversions. Microsoft.ML.Transforms.Conversions.TypeConversion Convert Microsoft.ML.Transforms.Conversions.TypeConvertingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.CombinerByContiguousGroupId Groups values of a scalar column into a vector, by a contiguous group ID Microsoft.ML.Transforms.GroupingOperations Group Microsoft.ML.Transforms.GroupTransform+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.ConditionalNormalizer Normalize the columns only if needed Microsoft.ML.Data.Normalize IfNeeded Microsoft.ML.Transforms.Normalizers.NormalizeTransform+MinMaxArguments Microsoft.ML.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput] Transforms.DataCache Caches using the specified cache option. Microsoft.ML.EntryPoints.Cache CacheData Microsoft.ML.EntryPoints.Cache+CacheInput Microsoft.ML.EntryPoints.Cache+CacheOutput @@ -87,9 +87,9 @@ Transforms.DatasetScorer Score a dataset with a predictor model Microsoft.ML.Ent Transforms.DatasetTransformScorer Score a dataset with a transform model Microsoft.ML.EntryPoints.ScoreModel ScoreUsingTransform Microsoft.ML.EntryPoints.ScoreModel+InputTransformScorer Microsoft.ML.EntryPoints.ScoreModel+Output Transforms.Dictionarizer Converts input values (words, numbers, etc.) to index in a dictionary. Microsoft.ML.Transforms.Text.TextAnalytics TermTransform Microsoft.ML.Transforms.Conversions.ValueToKeyMappingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.FeatureCombiner Combines all the features into one feature column. Microsoft.ML.EntryPoints.FeatureCombiner PrepareFeatures Microsoft.ML.EntryPoints.FeatureCombiner+FeatureCombinerInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -Transforms.FeatureContributionCalculationTransformer For each data point, calculates the contribution of individual features to the model prediction. Microsoft.ML.Data.FeatureContributionEntryPoint FeatureContributionCalculation Microsoft.ML.Data.FeatureContributionCalculatingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -Transforms.FeatureSelectorByCount Selects the slots for which the count of non-default values is greater than or equal to a threshold. Microsoft.ML.Transforms.SelectFeatures CountSelect Microsoft.ML.Transforms.FeatureSelection.CountFeatureSelectingEstimator+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -Transforms.FeatureSelectorByMutualInformation Selects the top k slots across all specified columns ordered by their mutual information with the label column. Microsoft.ML.Transforms.SelectFeatures MutualInformationSelect Microsoft.ML.Transforms.FeatureSelection.MutualInformationFeatureSelectingEstimator+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.FeatureContributionCalculationTransformer For each data point, calculates the contribution of individual features to the model prediction. Microsoft.ML.Data.FeatureContributionEntryPoint FeatureContributionCalculation Microsoft.ML.Data.FeatureContributionCalculatingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.FeatureSelectorByCount Selects the slots for which the count of non-default values is greater than or equal to a threshold. Microsoft.ML.Transforms.SelectFeatures CountSelect Microsoft.ML.Transforms.FeatureSelection.CountFeatureSelectingEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.FeatureSelectorByMutualInformation Selects the top k slots across all specified columns ordered by their mutual information with the label column. Microsoft.ML.Transforms.SelectFeatures MutualInformationSelect Microsoft.ML.Transforms.FeatureSelection.MutualInformationFeatureSelectingEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.GlobalContrastNormalizer Performs a global contrast normalization on input values: Y = (s * X - M) / D, where s is a scale, M is mean and D is either L2 norm or standard deviation. Microsoft.ML.Transforms.Projections.LpNormalization GcNormalize Microsoft.ML.Transforms.Projections.LpNormalizingTransformer+GcnOptions Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.HashConverter Converts column values into hashes. This transform accepts both numeric and text inputs, both single and vector-valued columns. Microsoft.ML.Transforms.Conversions.HashJoin Apply Microsoft.ML.Transforms.Conversions.HashJoiningTransform+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.ImageGrayscale Convert image into grayscale. Microsoft.ML.ImageAnalytics.EntryPoints.ImageAnalytics ImageGrayscale Microsoft.ML.ImageAnalytics.ImageGrayscalingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput @@ -108,7 +108,7 @@ Transforms.MeanVarianceNormalizer Normalizes the data based on the computed mean Transforms.MinMaxNormalizer Normalizes the data based on the observed minimum and maximum values of the data. Microsoft.ML.Data.Normalize MinMax Microsoft.ML.Transforms.Normalizers.NormalizeTransform+MinMaxArguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.MissingValueHandler Handle missing values by replacing them with either the default value or the mean/min/max value (for non-text columns only). An indicator column can optionally be concatenated, if theinput column type is numeric. Microsoft.ML.Transforms.NAHandling Handle Microsoft.ML.Transforms.MissingValueHandlingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.MissingValueIndicator Create a boolean output column with the same number of slots as the input column, where the output value is true if the value in the input column is missing. Microsoft.ML.Transforms.NAHandling Indicator Microsoft.ML.Transforms.MissingValueIndicatorTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput -Transforms.MissingValuesDropper Removes NAs from vector columns. Microsoft.ML.Transforms.NAHandling Drop Microsoft.ML.Transforms.MissingValueDroppingTransformer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.MissingValuesDropper Removes NAs from vector columns. Microsoft.ML.Transforms.NAHandling Drop Microsoft.ML.Transforms.MissingValueDroppingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.MissingValuesRowDropper Filters out rows that contain missing values. Microsoft.ML.Transforms.NAHandling Filter Microsoft.ML.Transforms.NAFilter+Arguments Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.MissingValueSubstitutor Create an output column of the same type and size of the input column, where missing values are replaced with either the default value or the mean/min/max value (for non-text columns only). Microsoft.ML.Transforms.NAHandling Replace Microsoft.ML.Transforms.MissingValueReplacingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.ModelCombiner Combines a sequence of TransformModels into a single model Microsoft.ML.EntryPoints.ModelOperations CombineTransformModels Microsoft.ML.EntryPoints.ModelOperations+CombineTransformModelsInput Microsoft.ML.EntryPoints.ModelOperations+CombineTransformModelsOutput diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index 7541959768..dc7c08fe81 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -834,7 +834,7 @@ public void SavePipeWithKey() Check(tmp, "Parsing argsText failed!"); IDataView view2 = TextLoader.Create(Env, argsText, new MultiFileSource(dataPath)); - var argsConv = new TypeConvertingTransformer.Arguments(); + var argsConv = new TypeConvertingTransformer.Options(); tmp = CmdParser.ParseArguments(Env, " col=Label:U1[0-1]:Label" + " col=Features:U2:Features" + @@ -848,7 +848,7 @@ public void SavePipeWithKey() Check(tmp, "Parsing argsConv failed!"); view2 = TypeConvertingTransformer.Create(Env, argsConv, view2); - argsConv = new TypeConvertingTransformer.Arguments(); + argsConv = new TypeConvertingTransformer.Options(); tmp = CmdParser.ParseArguments(Env, " col=Label2:U2:Label col=Features2:Num:Features", argsConv); diff --git a/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs b/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs index 29c7f8bf18..274db4198d 100644 --- a/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs @@ -78,8 +78,8 @@ public void TestConvertWorkout() var data = new[] { new TestClass() { A = 1, B = new int[2] { 1,4 } }, new TestClass() { A = 2, B = new int[2] { 3,4 } }}; var dataView = ML.Data.ReadFromEnumerable(data); - var pipe = new TypeConvertingEstimator(Env, columns: new[] {new TypeConvertingTransformer.ColumnInfo("ConvA", DataKind.R4, "A"), - new TypeConvertingTransformer.ColumnInfo("ConvB", DataKind.R4, "B")}); + var pipe = ML.Transforms.Conversion.ConvertType(columns: new[] {new TypeConvertingEstimator.ColumnInfo("ConvA", DataKind.R4, "A"), + new TypeConvertingEstimator.ColumnInfo("ConvB", DataKind.R4, "B")}); TestEstimatorCore(pipe, dataView); var allTypesData = new[] @@ -117,19 +117,19 @@ public void TestConvertWorkout() }; var allTypesDataView = ML.Data.ReadFromEnumerable(allTypesData); - var allTypesPipe = new TypeConvertingEstimator(Env, columns: new[] { - new TypeConvertingTransformer.ColumnInfo("ConvA", DataKind.R4, "AA"), - new TypeConvertingTransformer.ColumnInfo("ConvB", DataKind.R4, "AB"), - new TypeConvertingTransformer.ColumnInfo("ConvC", DataKind.R4, "AC"), - new TypeConvertingTransformer.ColumnInfo("ConvD", DataKind.R4, "AD"), - new TypeConvertingTransformer.ColumnInfo("ConvE", DataKind.R4, "AE"), - new TypeConvertingTransformer.ColumnInfo("ConvF", DataKind.R4, "AF"), - new TypeConvertingTransformer.ColumnInfo("ConvG", DataKind.R4, "AG"), - new TypeConvertingTransformer.ColumnInfo("ConvH", DataKind.R4, "AH"), - new TypeConvertingTransformer.ColumnInfo("ConvK", DataKind.R4, "AK"), - new TypeConvertingTransformer.ColumnInfo("ConvL", DataKind.R4, "AL"), - new TypeConvertingTransformer.ColumnInfo("ConvM", DataKind.R4, "AM"), - new TypeConvertingTransformer.ColumnInfo("ConvN", DataKind.R4, "AN")} + var allTypesPipe = ML.Transforms.Conversion.ConvertType(columns: new[] { + new TypeConvertingEstimator.ColumnInfo("ConvA", DataKind.R4, "AA"), + new TypeConvertingEstimator.ColumnInfo("ConvB", DataKind.R4, "AB"), + new TypeConvertingEstimator.ColumnInfo("ConvC", DataKind.R4, "AC"), + new TypeConvertingEstimator.ColumnInfo("ConvD", DataKind.R4, "AD"), + new TypeConvertingEstimator.ColumnInfo("ConvE", DataKind.R4, "AE"), + new TypeConvertingEstimator.ColumnInfo("ConvF", DataKind.R4, "AF"), + new TypeConvertingEstimator.ColumnInfo("ConvG", DataKind.R4, "AG"), + new TypeConvertingEstimator.ColumnInfo("ConvH", DataKind.R4, "AH"), + new TypeConvertingEstimator.ColumnInfo("ConvK", DataKind.R4, "AK"), + new TypeConvertingEstimator.ColumnInfo("ConvL", DataKind.R4, "AL"), + new TypeConvertingEstimator.ColumnInfo("ConvM", DataKind.R4, "AM"), + new TypeConvertingEstimator.ColumnInfo("ConvN", DataKind.R4, "AN")} ); TestEstimatorCore(allTypesPipe, allTypesDataView); @@ -192,8 +192,8 @@ public void TestOldSavingAndLoading() var data = new[] { new TestClass() { A = 1, B = new int[2] { 1,4 } }, new TestClass() { A = 2, B = new int[2] { 3,4 } }}; var dataView = ML.Data.ReadFromEnumerable(data); - var pipe = new TypeConvertingEstimator(Env, columns: new[] {new TypeConvertingTransformer.ColumnInfo("ConvA", DataKind.R8, "A"), - new TypeConvertingTransformer.ColumnInfo("ConvB", DataKind.R8, "B")}); + var pipe = ML.Transforms.Conversion.ConvertType(columns: new[] {new TypeConvertingEstimator.ColumnInfo("ConvA", typeof(double), "A"), + new TypeConvertingEstimator.ColumnInfo("ConvB", typeof(double), "B")}); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); @@ -213,9 +213,9 @@ public void TestMetadata() var pipe = ML.Transforms.Categorical.OneHotEncoding(new[] { new OneHotEncodingEstimator.ColumnInfo("CatA", "A", OneHotEncodingTransformer.OutputKind.Ind), new OneHotEncodingEstimator.ColumnInfo("CatB", "B", OneHotEncodingTransformer.OutputKind.Key) - }).Append(new TypeConvertingEstimator(Env, new[] { - new TypeConvertingTransformer.ColumnInfo("ConvA", DataKind.R8, "CatA"), - new TypeConvertingTransformer.ColumnInfo("ConvB", DataKind.U2, "CatB") + }).Append(ML.Transforms.Conversion.ConvertType(new[] { + new TypeConvertingEstimator.ColumnInfo("ConvA", DataKind.R8, "CatA"), + new TypeConvertingEstimator.ColumnInfo("ConvB", DataKind.U2, "CatB") })); var dataView = ML.Data.ReadFromEnumerable(data); dataView = pipe.Fit(dataView).Transform(dataView); @@ -249,7 +249,7 @@ public class SimpleSchemaUIntColumn public void TypeConvertKeyBackCompatTest() { // Model generated using the following command before the change removing Min and Count from KeyType. - // ML.Transforms.Conversion.ConvertType(new[] { new TypeConvertingTransformer.ColumnInfo("key", "convertedKey", + // ML.Transforms.Conversion.ConvertType(new[] { new TypeConvertingEstimator.ColumnInfo("key", "convertedKey", // DataKind.U8, new KeyCount(4)) }).Fit(dataView); var dataArray = new[] { @@ -272,7 +272,7 @@ public void TypeConvertKeyBackCompatTest() } var outDataOld = modelOld.Transform(dataView); - var modelNew = ML.Transforms.Conversion.ConvertType(new[] { new TypeConvertingTransformer.ColumnInfo("convertedKey", + var modelNew = ML.Transforms.Conversion.ConvertType(new[] { new TypeConvertingEstimator.ColumnInfo("convertedKey", DataKind.U8, "key", new KeyCount(4)) }).Fit(dataView); var outDataNew = modelNew.Transform(dataView); diff --git a/test/Microsoft.ML.Tests/Transformers/FeatureSelectionTests.cs b/test/Microsoft.ML.Tests/Transformers/FeatureSelectionTests.cs index 9a9161a443..f2411fb6a4 100644 --- a/test/Microsoft.ML.Tests/Transformers/FeatureSelectionTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/FeatureSelectionTests.cs @@ -121,8 +121,8 @@ public void CountFeatureSelectionWorkout() new CountFeatureSelectingEstimator.ColumnInfo("VecFeatureSelectMissing690", "VectorDouble", minCount: 690), new CountFeatureSelectingEstimator.ColumnInfo("VecFeatureSelectMissing100", "VectorDouble", minCount: 100) }; - var est = new CountFeatureSelectingEstimator(ML, "FeatureSelect", "VectorFloat", minCount: 1) - .Append(new CountFeatureSelectingEstimator(ML, columns)); + var est = ML.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("FeatureSelect", "VectorFloat", count: 1) + .Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnCount(columns)); TestEstimatorCore(est, data); @@ -156,7 +156,7 @@ public void TestCountSelectOldSavingAndLoading() var dataView = reader.Read(new MultiFileSource(dataPath)).AsDynamic; - var pipe = new CountFeatureSelectingEstimator(ML, "FeatureSelect", "VectorFloat", minCount: 1); + var pipe = ML.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("FeatureSelect", "VectorFloat", count: 1); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); @@ -182,8 +182,8 @@ public void MutualInformationSelectionWorkout() var data = reader.Read(new MultiFileSource(dataPath)).AsDynamic; - var est = new MutualInformationFeatureSelectingEstimator(ML, "FeatureSelect", "VectorFloat", slotsInOutput: 1, labelColumn: "Label") - .Append(new MutualInformationFeatureSelectingEstimator(ML, labelColumn: "Label", slotsInOutput: 2, numBins: 100, + var est = ML.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("FeatureSelect", "VectorFloat", slotsInOutput: 1, labelColumn: "Label") + .Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation(labelColumn: "Label", slotsInOutput: 2, numBins: 100, columns: new[] { (name: "out1", source: "VectorFloat"), (name: "out2", source: "VectorDouble") @@ -220,7 +220,7 @@ public void TestMutualInformationOldSavingAndLoading() var dataView = reader.Read(new MultiFileSource(dataPath)).AsDynamic; - var pipe = new MutualInformationFeatureSelectingEstimator(ML, "FeatureSelect", "VectorFloat", slotsInOutput: 1, labelColumn: "Label"); + var pipe = ML.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("FeatureSelect", "VectorFloat", slotsInOutput: 1, labelColumn: "Label"); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result);