diff --git a/src/Microsoft.ML.Transforms/CategoricalCatalog.cs b/src/Microsoft.ML.Transforms/CategoricalCatalog.cs
index cdb526b0a2..9f3b3b6547 100644
--- a/src/Microsoft.ML.Transforms/CategoricalCatalog.cs
+++ b/src/Microsoft.ML.Transforms/CategoricalCatalog.cs
@@ -43,13 +43,17 @@ public static OneHotEncodingEstimator OneHotEncoding(this TransformsCatalog.Cate
/// The transform catalog
/// The input column
/// The output column. If null, is used.
+ /// Number of bits to hash into. Must be between 1 and 30, inclusive.
+ /// Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.
/// The conversion mode.
///
public static OneHotHashEncodingEstimator OneHotHashEncoding(this TransformsCatalog.CategoricalTransforms catalog,
string inputColumn,
string outputColumn = null,
+ int hashBits = OneHotHashEncodingEstimator.Defaults.HashBits,
+ int invertHash = OneHotHashEncodingEstimator.Defaults.InvertHash,
OneHotEncodingTransformer.OutputKind outputKind = OneHotEncodingTransformer.OutputKind.Ind)
- => new OneHotHashEncodingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, outputKind);
+ => new OneHotHashEncodingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, hashBits, invertHash, outputKind);
///
/// Convert several text column into hash-based one-hot encoded vectors.
diff --git a/src/Microsoft.ML.Transforms/OneHotHashEncodingTransformer.cs b/src/Microsoft.ML.Transforms/OneHotHashEncodingTransformer.cs
index 41fa1bc855..8bd7856c7e 100644
--- a/src/Microsoft.ML.Transforms/OneHotHashEncodingTransformer.cs
+++ b/src/Microsoft.ML.Transforms/OneHotHashEncodingTransformer.cs
@@ -145,7 +145,7 @@ public static IDataView Create(IHostEnvironment env,
int invertHash = OneHotHashEncodingEstimator.Defaults.InvertHash,
OneHotEncodingTransformer.OutputKind outputKind = OneHotHashEncodingEstimator.Defaults.OutputKind)
{
- return new OneHotHashEncodingEstimator(env, name, source, outputKind).Fit(input).Transform(input) as IDataView;
+ return new OneHotHashEncodingEstimator(env, name, source, hashBits, invertHash, outputKind).Fit(input).Transform(input) as IDataView;
}
internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
@@ -240,10 +240,16 @@ public ColumnInfo(string input, string output,
/// Host Environment.
/// Name of the input column.
/// Name of the output column. If this is null '' will be used.
+ /// Number of bits to hash into. Must be between 1 and 30, inclusive.
+ /// Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.
/// The type of output expected.
- public OneHotHashEncodingEstimator(IHostEnvironment env, string inputColumn,
- string outputColumn = null, OneHotEncodingTransformer.OutputKind outputKind = Defaults.OutputKind)
- : this(env, new ColumnInfo(inputColumn, outputColumn ?? inputColumn, outputKind))
+ public OneHotHashEncodingEstimator(IHostEnvironment env,
+ string inputColumn,
+ string outputColumn,
+ int hashBits = OneHotHashEncodingEstimator.Defaults.HashBits,
+ int invertHash = OneHotHashEncodingEstimator.Defaults.InvertHash,
+ OneHotEncodingTransformer.OutputKind outputKind = Defaults.OutputKind)
+ : this(env, new ColumnInfo(inputColumn, outputColumn ?? inputColumn, outputKind, hashBits, invertHash: invertHash))
{
}
diff --git a/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs
index 7dc1e30f0f..f9d7f58876 100644
--- a/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs
@@ -64,6 +64,20 @@ public void CategoricalWorkout()
Done();
}
+ [Fact]
+ public void CategoricalOneHotHashEncoding()
+ {
+ var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
+
+ var mlContext = new MLContext();
+ var dataView = ComponentCreation.CreateDataView(mlContext, data);
+
+ var pipe = mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatA", 16, 0, OneHotEncodingTransformer.OutputKind.Bag);
+
+ TestEstimatorCore(pipe, dataView);
+ Done();
+ }
+
[Fact]
public void CategoricalStatic()
{