diff --git a/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs
index 5efc9264f1..1459f55cab 100644
--- a/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs
@@ -58,6 +58,20 @@ public bool TryUnparse(StringBuilder sb)
public sealed class Arguments
{
+ public Arguments() { } // explicit: declaring the overload below suppresses the compiler-generated parameterless constructor
+
+ /// <summary>Builds one pass-through column (output name equal to source name) per entry of <paramref name="columns"/>.</summary>
+ /// <param name="columns">Names of the columns to keep. A null array is rejected with a contract error naming the parameter.</param>
+ internal Arguments(params string[] columns)
+ {
+ Contracts.CheckRef(columns, nameof(columns));
+ Column = new Column[columns.Length];
+ for (int i = 0; i < columns.Length; i++)
+ {
+ Column[i] = new Column() { Source = columns[i], Name = columns[i] };
+ }
+ }
+
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
public Column[] Column;
@@ -442,6 +456,17 @@ private static VersionInfo GetVersionInfo()
private const string RegistrationName = "ChooseColumns";
+ /// <summary>
+ /// Convenience constructor for public facing API.
+ /// </summary>
+ /// <param name="env">Host Environment.</param>
+ /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
+ /// <param name="columns">Names of the columns to choose.</param>
+ public ChooseColumnsTransform(IHostEnvironment env, IDataView input, params string[] columns)
+ : this(env, new Arguments(columns), input)
+ {
+ }
+
///
/// Public constructor corresponding to SignatureDataTransform.
///
diff --git a/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs b/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs
index c37f0a6983..52005c7558 100644
--- a/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs
@@ -169,6 +169,23 @@ private static VersionInfo GetVersionInfo()
// This is parallel to Infos.
private readonly ColInfoEx[] _exes;
+ /// <summary>
+ /// Convenience constructor for public facing API.
+ /// </summary>
+ /// <param name="env">Host Environment.</param>
+ /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
+ /// <param name="resultType">The expected type of the converted column.</param>
+ /// <param name="name">Name of the output column.</param>
+ /// <param name="source">Name of the column to be converted. If this is null '<paramref name="name"/>' will be used.</param>
+ public ConvertTransform(IHostEnvironment env,
+ IDataView input,
+ DataKind resultType,
+ string name,
+ string source = null)
+ : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, ResultType = resultType }, input)
+ {
+ }
+
public ConvertTransform(IHostEnvironment env, Arguments args, IDataView input)
: base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column,
input, null)
diff --git a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs
index f80589bdab..713f85f9df 100644
--- a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs
@@ -77,16 +77,22 @@ private bool TryParse(string str)
}
}
+ private static class Defaults // Default values for the Arguments fields; UseCounter is also the convenience constructor's optional-parameter default.
+ {
+ public const bool UseCounter = false;
+ public const uint Seed = 42;
+ }
+
public sealed class Arguments : TransformInputBase
{
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:seed)", ShortName = "col", SortOrder = 1)]
public Column[] Column;
[Argument(ArgumentType.AtMostOnce, HelpText = "Use an auto-incremented integer starting at zero instead of a random number", ShortName = "cnt")]
- public bool UseCounter;
+ public bool UseCounter = Defaults.UseCounter;
[Argument(ArgumentType.AtMostOnce, HelpText = "The random seed")]
- public uint Seed = 42;
+ public uint Seed = Defaults.Seed;
}
private sealed class Bindings : ColumnBindingsBase
@@ -250,6 +256,18 @@ private static VersionInfo GetVersionInfo()
private const string RegistrationName = "GenerateNumber";
+ /// <summary>
+ /// Convenience constructor for public facing API.
+ /// </summary>
+ /// <param name="env">Host Environment.</param>
+ /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
+ /// <param name="name">Name of the output column.</param>
+ /// <param name="useCounter">Use an auto-incremented integer starting at zero instead of a random number.</param>
+ public GenerateNumberTransform(IHostEnvironment env, IDataView input, string name, bool useCounter = Defaults.UseCounter)
+ : this(env, new Arguments() { Column = new[] { new Column() { Name = name } }, UseCounter = useCounter }, input)
+ {
+ }
+
///
/// Public constructor corresponding to SignatureDataTransform.
///
diff --git a/src/Microsoft.ML.Data/Transforms/HashTransform.cs b/src/Microsoft.ML.Data/Transforms/HashTransform.cs
index ca959069f7..23ba5592b7 100644
--- a/src/Microsoft.ML.Data/Transforms/HashTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/HashTransform.cs
@@ -33,6 +33,14 @@ public sealed class HashTransform : OneToOneTransformBase, ITransformTemplate
public const int NumBitsMin = 1;
public const int NumBitsLim = 32;
+ private static class Defaults // Default values for the Arguments fields; HashBits and InvertHash are also the convenience constructor's optional-parameter defaults.
+ {
+ public const int HashBits = NumBitsLim - 1; // 31, the upper bound quoted in the 'bits' help text
+ public const uint Seed = 314489979;
+ public const bool Ordered = false;
+ public const int InvertHash = 0; // 0 means no invert hashing, per the 'ih' help text
+ }
+
public sealed class Arguments
{
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col",
@@ -41,18 +49,18 @@ public sealed class Arguments
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of bits to hash into. Must be between 1 and 31, inclusive",
ShortName = "bits", SortOrder = 2)]
- public int HashBits = NumBitsLim - 1;
+ public int HashBits = Defaults.HashBits;
[Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")]
- public uint Seed = 314489979;
+ public uint Seed = Defaults.Seed;
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether the position of each term should be included in the hash",
ShortName = "ord")]
- public bool Ordered;
+ public bool Ordered = Defaults.Ordered;
[Argument(ArgumentType.AtMostOnce, HelpText = "Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.",
ShortName = "ih")]
- public int InvertHash;
+ public int InvertHash = Defaults.InvertHash;
}
public sealed class Column : OneToOneColumn
@@ -234,6 +242,27 @@ public override void Save(ModelSaveContext ctx)
TextModelHelper.SaveAll(Host, ctx, Infos.Length, _keyValues);
}
+ /// <summary>
+ /// Convenience constructor for public facing API.
+ /// </summary>
+ /// <param name="env">Host Environment.</param>
+ /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
+ /// <param name="name">Name of the output column.</param>
+ /// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
+ /// <param name="hashBits">Number of bits to hash into. Must be between 1 and 31, inclusive.</param>
+ /// <param name="invertHash">Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.</param>
+ public HashTransform(IHostEnvironment env,
+ IDataView input,
+ string name,
+ string source = null,
+ int hashBits = Defaults.HashBits,
+ int invertHash = Defaults.InvertHash)
+ : this(env, new Arguments() {
+ Column = new[] { new Column() { Source = source ?? name, Name = name } },
+ HashBits = hashBits, InvertHash = invertHash }, input)
+ {
+ }
+
public HashTransform(IHostEnvironment env, Arguments args, IDataView input)
: base(Contracts.CheckRef(env, nameof(env)), RegistrationName, env.CheckRef(args, nameof(args)).Column,
input, TestType)
diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs
index 165ab7e0df..7c1fa19c10 100644
--- a/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs
@@ -73,6 +73,19 @@ private static VersionInfo GetVersionInfo()
private readonly ColumnType[] _types;
private KeyToValueMap[] _kvMaps;
+ /// <summary>
+ /// Convenience constructor for public facing API.
+ /// </summary>
+ /// <param name="env">Host Environment.</param>
+ /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
+ /// <param name="name">Name of the output column.</param>
+ /// <param name="source">Name of the input column. If this is null '<paramref name="name"/>' will be used.</param>
+ public KeyToValueTransform(IHostEnvironment env, IDataView input, string name, string source = null)
+ : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } } }, input)
+ {
+ }
+
+
///
/// Public constructor corresponding to SignatureDataTransform.
///
diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs
index d177d75647..0f4b616a49 100644
--- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs
@@ -70,6 +70,11 @@ public bool TryUnparse(StringBuilder sb)
}
}
+ private static class Defaults // Default shared by the Arguments.Bag field and the convenience constructor's optional parameter.
+ {
+ public const bool Bag = false;
+ }
+
public sealed class Arguments
{
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
@@ -77,7 +82,7 @@ public sealed class Arguments
[Argument(ArgumentType.AtMostOnce,
HelpText = "Whether to combine multiple indicator vectors into a single bag vector instead of concatenating them. This is only relevant when the input is a vector.")]
- public bool Bag;
+ public bool Bag = Defaults.Bag;
}
internal const string Summary = "Converts a key column to an indicator vector.";
@@ -112,6 +117,23 @@ private static VersionInfo GetVersionInfo()
private readonly bool[] _concat;
private readonly VectorType[] _types;
+ /// <summary>
+ /// Convenience constructor for public facing API.
+ /// </summary>
+ /// <param name="env">Host Environment.</param>
+ /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
+ /// <param name="name">Name of the output column.</param>
+ /// <param name="source">Name of the input column. If this is null '<paramref name="name"/>' will be used.</param>
+ /// <param name="bag">Whether to combine multiple indicator vectors into a single bag vector instead of concatenating them. This is only relevant when the input is a vector.</param>
+ public KeyToVectorTransform(IHostEnvironment env,
+ IDataView input,
+ string name,
+ string source = null,
+ bool bag = Defaults.Bag)
+ : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, Bag = bag }, input)
+ {
+ }
+
///
/// Public constructor corresponding to SignatureDataTransform.
///
diff --git a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs
index 5329d89a57..8817833f40 100644
--- a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs
@@ -64,6 +64,18 @@ private static VersionInfo GetVersionInfo()
private const string RegistrationName = "LabelConvert";
private VectorType _slotType;
+ /// <summary>
+ /// Convenience constructor for public facing API.
+ /// </summary>
+ /// <param name="env">Host Environment.</param>
+ /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
+ /// <param name="name">Name of the output column.</param>
+ /// <param name="source">Name of the input column. If this is null '<paramref name="name"/>' will be used.</param>
+ public LabelConvertTransform(IHostEnvironment env, IDataView input, string name, string source = null)
+ : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } } }, input)
+ {
+ }
+
public LabelConvertTransform(IHostEnvironment env, Arguments args, IDataView input)
: base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, input, RowCursorUtils.TestGetLabelGetter)
{
diff --git a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs
index 81a91b5f17..a7672b5a1c 100644
--- a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs
@@ -111,6 +111,23 @@ private static string TestIsMulticlassLabel(ColumnType type)
return $"Label column type is not supported for binary remapping: {type}. Supported types: key, float, double.";
}
+ /// <summary>
+ /// Convenience constructor for public facing API.
+ /// </summary>
+ /// <param name="env">Host Environment.</param>
+ /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
+ /// <param name="classIndex">Label of the positive class.</param>
+ /// <param name="name">Name of the output column.</param>
+ /// <param name="source">Name of the input column. If this is null '<paramref name="name"/>' will be used.</param>
+ public LabelIndicatorTransform(IHostEnvironment env,
+ IDataView input,
+ int classIndex,
+ string name,
+ string source = null)
+ : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, ClassIndex = classIndex }, input)
+ {
+ }
+
public LabelIndicatorTransform(IHostEnvironment env, Arguments args, IDataView input)
: base(env, LoadName, Contracts.CheckRef(args, nameof(args)).Column,
input, TestIsMulticlassLabel)
diff --git a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs
index b9ab10f4c1..589a635aff 100644
--- a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs
+++ b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs
@@ -77,6 +77,19 @@ private static VersionInfo GetVersionInfo()
private readonly bool _includeMin;
private readonly bool _includeMax;
+ /// <summary>
+ /// Convenience constructor for public facing API.
+ /// </summary>
+ /// <param name="env">Host Environment.</param>
+ /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
+ /// <param name="column">Name of the input column.</param>
+ /// <param name="minimum">Minimum value (0 to 1 for key types).</param>
+ /// <param name="maximum">Maximum value (0 to 1 for key types).</param>
+ public RangeFilter(IHostEnvironment env, IDataView input, string column, Double? minimum = null, Double? maximum = null)
+ : this(env, new Arguments() { Column = column, Min = minimum, Max = maximum }, input)
+ {
+ }
+
public RangeFilter(IHostEnvironment env, Arguments args, IDataView input)
: base(env, RegistrationName, input)
{
diff --git a/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs b/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs
index 37e52ee2da..5080208335 100644
--- a/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs
@@ -33,18 +33,25 @@ namespace Microsoft.ML.Runtime.Data
///
public sealed class ShuffleTransform : RowToRowTransformBase
{
+ private static class Defaults // Defaults shared by the Arguments fields and the convenience constructor's optional parameters.
+ {
+ public const int PoolRows = 1000;
+ public const bool PoolOnly = false;
+ public const bool ForceShuffle = false;
+ }
+
public sealed class Arguments
{
// REVIEW: A more intelligent heuristic, based on the expected size of the inputs, perhaps?
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The pool will have this many rows", ShortName = "rows")]
- public int PoolRows = 1000;
+ public int PoolRows = Defaults.PoolRows;
// REVIEW: Come up with a better way to specify the desired set of functionality.
[Argument(ArgumentType.LastOccurenceWins, HelpText = "If true, the transform will not attempt to shuffle the input cursor but only shuffle based on the pool. This parameter has no effect if the input data was not itself shufflable.", ShortName = "po")]
- public bool PoolOnly;
+ public bool PoolOnly = Defaults.PoolOnly;
[Argument(ArgumentType.LastOccurenceWins, HelpText = "If true, the transform will always provide a shuffled view.", ShortName = "force")]
- public bool ForceShuffle;
+ public bool ForceShuffle = Defaults.ForceShuffle;
[Argument(ArgumentType.LastOccurenceWins, HelpText = "If true, the transform will always shuffle the input. The default value is the same as forceShuffle.", ShortName = "forceSource")]
public bool? ForceShuffleSource;
@@ -79,6 +86,23 @@ private static VersionInfo GetVersionInfo()
// know how to copy other types of values.
private readonly IDataView _subsetInput;
+ /// <summary>
+ /// Convenience constructor for public facing API.
+ /// </summary>
+ /// <param name="env">Host Environment.</param>
+ /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
+ /// <param name="poolRows">The pool will have this many rows.</param>
+ /// <param name="poolOnly">If true, the transform will not attempt to shuffle the input cursor but only shuffle based on the pool. This parameter has no effect if the input data was not itself shufflable.</param>
+ /// <param name="forceShuffle">If true, the transform will always provide a shuffled view.</param>
+ public ShuffleTransform(IHostEnvironment env,
+ IDataView input,
+ int poolRows = Defaults.PoolRows,
+ bool poolOnly = Defaults.PoolOnly,
+ bool forceShuffle = Defaults.ForceShuffle)
+ : this(env, new Arguments() { PoolRows = poolRows, PoolOnly = poolOnly, ForceShuffle = forceShuffle }, input)
+ {
+ }
+
///
/// Public constructor corresponding to SignatureDataTransform.
///
diff --git a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs
index 278f3ee418..bfd3522f73 100644
--- a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs
+++ b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs
@@ -60,13 +60,13 @@ public sealed class Arguments : TransformInputBase
public sealed class TakeArguments : TransformInputBase
{
[Argument(ArgumentType.Required, HelpText = Arguments.TakeHelp, ShortName = "c,n,t", SortOrder = 1)]
- public long Count = long.MaxValue;
+ public long Count = Arguments.DefaultTake;
}
public sealed class SkipArguments : TransformInputBase
{
[Argument(ArgumentType.Required, HelpText = Arguments.SkipHelp, ShortName = "c,n,s", SortOrder = 1)]
- public long Count = 0;
+ public long Count = Arguments.DefaultSkip;
}
private static VersionInfo GetVersionInfo()
@@ -270,4 +270,30 @@ protected override bool MoveManyCore(long count)
}
}
}
+
+ public static class SkipFilter
+ {
+ /// <summary>
+ /// A helper method to create <see cref="SkipTakeFilter"/> transform for skipping the number of rows defined by the <paramref name="count"/> parameter.
+ /// <see cref="SkipTakeFilter"/> when created with <see cref="SkipTakeFilter.SkipArguments"/> behaves as 'SkipFilter'.
+ /// </summary>
+ /// <param name="env">Host Environment.</param>
+ /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
+ /// <param name="count">Number of rows to skip.</param>
+ public static IDataTransform Create(IHostEnvironment env, IDataView input, long count = SkipTakeFilter.Arguments.DefaultSkip)
+ => SkipTakeFilter.Create(env, new SkipTakeFilter.SkipArguments() { Count = count }, input);
+ }
+
+ public static class TakeFilter
+ {
+ /// <summary>
+ /// A helper method to create <see cref="SkipTakeFilter"/> transform by taking the top rows defined by the <paramref name="count"/> parameter.
+ /// <see cref="SkipTakeFilter"/> when created with <see cref="SkipTakeFilter.TakeArguments"/> behaves as 'TakeFilter'.
+ /// </summary>
+ /// <param name="env">Host Environment.</param>
+ /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
+ /// <param name="count">Number of rows to take.</param>
+ public static IDataTransform Create(IHostEnvironment env, IDataView input, long count = SkipTakeFilter.Arguments.DefaultTake)
+ => SkipTakeFilter.Create(env, new SkipTakeFilter.TakeArguments() { Count = count }, input);
+ }
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs
index bb9adf21e1..7591179588 100644
--- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs
@@ -97,10 +97,16 @@ public enum SortOrder : byte
// other things, like case insensitive (where appropriate), culturally aware, etc.?
}
+ private static class Defaults // Defaults shared by the ArgumentsBase fields and the convenience constructor's optional parameters.
+ {
+ public const int MaxNumTerms = 1000000;
+ public const SortOrder Sort = SortOrder.Occurrence;
+ }
+
public abstract class ArgumentsBase : TransformInputBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of terms to keep per column when auto-training", ShortName = "max", SortOrder = 5)]
- public int MaxNumTerms = 1000000;
+ public int MaxNumTerms = Defaults.MaxNumTerms;
[Argument(ArgumentType.AtMostOnce, HelpText = "Comma separated list of terms", SortOrder = 105, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
public string Terms;
@@ -124,7 +130,7 @@ public abstract class ArgumentsBase : TransformInputBase
// REVIEW: Should we always sort? Opinions are mixed. See work item 7797429.
[Argument(ArgumentType.AtMostOnce, HelpText = "How items should be ordered when vectorized. By default, they will be in the order encountered. " +
"If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a').", SortOrder = 113)]
- public SortOrder Sort = SortOrder.Occurrence;
+ public SortOrder Sort = Defaults.Sort;
// REVIEW: Should we do this here, or correct the various pieces of code here and in MRS etc. that
// assume key-values will be string? Once we correct these things perhaps we can see about removing it.
@@ -196,6 +202,26 @@ private CodecFactory CodecFactory
public override bool CanSavePfa => true;
public override bool CanSaveOnnx => true;
+ /// <summary>
+ /// Convenience constructor for public facing API.
+ /// </summary>
+ /// <param name="env">Host Environment.</param>
+ /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
+ /// <param name="name">Name of the output column.</param>
+ /// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
+ /// <param name="maxNumTerms">Maximum number of terms to keep per column when auto-training.</param>
+ /// <param name="sort">How items should be ordered when vectorized. By default, they will be in the order encountered.
+ /// If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a').</param>
+ public TermTransform(IHostEnvironment env,
+ IDataView input,
+ string name,
+ string source = null,
+ int maxNumTerms = Defaults.MaxNumTerms,
+ SortOrder sort = Defaults.Sort)
+ : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } }, MaxNumTerms = maxNumTerms, Sort = sort }, input)
+ {
+ }
+
///
/// Public constructor corresponding to SignatureDataTransform.
///