diff --git a/src/Microsoft.ML.Transforms/CategoricalCatalog.cs b/src/Microsoft.ML.Transforms/CategoricalCatalog.cs
index cdb526b0a2..9f3b3b6547 100644
--- a/src/Microsoft.ML.Transforms/CategoricalCatalog.cs
+++ b/src/Microsoft.ML.Transforms/CategoricalCatalog.cs
@@ -43,13 +43,17 @@ public static OneHotEncodingEstimator OneHotEncoding(this TransformsCatalog.Cate
         /// <param name="catalog">The transform catalog</param>
         /// <param name="inputColumn">The input column</param>
         /// <param name="outputColumn">The output column. If null, <paramref name="inputColumn"/> is used.</param>
+        /// <param name="hashBits">Number of bits to hash into. Must be between 1 and 30, inclusive.</param>
+        /// <param name="invertHash">Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.</param>
         /// <param name="outputKind">The conversion mode.</param>
         /// <returns></returns>
         public static OneHotHashEncodingEstimator OneHotHashEncoding(this TransformsCatalog.CategoricalTransforms catalog,
                 string inputColumn,
                 string outputColumn = null,
+                int hashBits = OneHotHashEncodingEstimator.Defaults.HashBits,
+                int invertHash = OneHotHashEncodingEstimator.Defaults.InvertHash,
                 OneHotEncodingTransformer.OutputKind outputKind = OneHotEncodingTransformer.OutputKind.Ind)
-            => new OneHotHashEncodingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, outputKind);
+            => new OneHotHashEncodingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, hashBits, invertHash, outputKind);
 
         /// <summary>
         /// Convert several text column into hash-based one-hot encoded vectors.
diff --git a/src/Microsoft.ML.Transforms/OneHotHashEncodingTransformer.cs b/src/Microsoft.ML.Transforms/OneHotHashEncodingTransformer.cs
index 41fa1bc855..8bd7856c7e 100644
--- a/src/Microsoft.ML.Transforms/OneHotHashEncodingTransformer.cs
+++ b/src/Microsoft.ML.Transforms/OneHotHashEncodingTransformer.cs
@@ -145,7 +145,7 @@ public static IDataView Create(IHostEnvironment env,
             int invertHash = OneHotHashEncodingEstimator.Defaults.InvertHash,
             OneHotEncodingTransformer.OutputKind outputKind = OneHotHashEncodingEstimator.Defaults.OutputKind)
         {
-            return new OneHotHashEncodingEstimator(env, name, source, outputKind).Fit(input).Transform(input) as IDataView;
+            return new OneHotHashEncodingEstimator(env, name, source, hashBits, invertHash, outputKind).Fit(input).Transform(input) as IDataView;
         }
 
         internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
@@ -240,10 +240,16 @@ public ColumnInfo(string input, string output,
         /// <param name="env">Host Environment.</param>
         /// <param name="inputColumn">Name of the input column.</param>
         /// <param name="outputColumn">Name of the output column. If this is null '<paramref name="inputColumn"/>' will be used.</param>
+        /// <param name="hashBits">Number of bits to hash into. Must be between 1 and 30, inclusive.</param>
+        /// <param name="invertHash">Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.</param>
         /// <param name="outputKind">The type of output expected.</param>
-        public OneHotHashEncodingEstimator(IHostEnvironment env, string inputColumn,
-            string outputColumn = null, OneHotEncodingTransformer.OutputKind outputKind = Defaults.OutputKind)
-            : this(env, new ColumnInfo(inputColumn, outputColumn ?? inputColumn, outputKind))
+        public OneHotHashEncodingEstimator(IHostEnvironment env,
+            string inputColumn,
+            string outputColumn,
+            int hashBits = OneHotHashEncodingEstimator.Defaults.HashBits,
+            int invertHash = OneHotHashEncodingEstimator.Defaults.InvertHash,
+            OneHotEncodingTransformer.OutputKind outputKind = Defaults.OutputKind)
+            : this(env, new ColumnInfo(inputColumn, outputColumn ?? inputColumn, outputKind, hashBits, invertHash: invertHash))
         {
         }
 
diff --git a/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs
index 7dc1e30f0f..f9d7f58876 100644
--- a/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs
@@ -64,6 +64,20 @@ public void CategoricalWorkout()
             Done();
         }
 
+        [Fact]
+        public void CategoricalOneHotHashEncoding()
+        {
+            var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
+
+            var mlContext = new MLContext();
+            var dataView = ComponentCreation.CreateDataView(mlContext, data);
+
+            var pipe = mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatA", 16, 0, OneHotEncodingTransformer.OutputKind.Bag);
+
+            TestEstimatorCore(pipe, dataView);
+            Done();
+        }
+
         [Fact]
         public void CategoricalStatic()
         {