Skip to content

Commit e090660

Browse files
feiyun0112Ivanidzo4ka
authored andcommitted
Pass hashBits, invertHash to OneHotHashEncodingEstimator (#1564)
1 parent e6e07ed commit e090660

File tree

3 files changed

+29
-5
lines changed

3 files changed

+29
-5
lines changed

src/Microsoft.ML.Transforms/CategoricalCatalog.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,17 @@ public static OneHotEncodingEstimator OneHotEncoding(this TransformsCatalog.Cate
4343
/// <param name="catalog">The transform catalog</param>
4444
/// <param name="inputColumn">The input column</param>
4545
/// <param name="outputColumn">The output column. If <c>null</c>, <paramref name="inputColumn"/> is used.</param>
46+
/// <param name="hashBits">Number of bits to hash into. Must be between 1 and 30, inclusive.</param>
47+
/// <param name="invertHash">Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.</param>
4648
/// <param name="outputKind">The conversion mode.</param>
4749
/// <returns></returns>
4850
public static OneHotHashEncodingEstimator OneHotHashEncoding(this TransformsCatalog.CategoricalTransforms catalog,
4951
string inputColumn,
5052
string outputColumn = null,
53+
int hashBits = OneHotHashEncodingEstimator.Defaults.HashBits,
54+
int invertHash = OneHotHashEncodingEstimator.Defaults.InvertHash,
5155
OneHotEncodingTransformer.OutputKind outputKind = OneHotEncodingTransformer.OutputKind.Ind)
52-
=> new OneHotHashEncodingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, outputKind);
56+
=> new OneHotHashEncodingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, hashBits, invertHash, outputKind);
5357

5458
/// <summary>
5559
/// Convert several text column into hash-based one-hot encoded vectors.

src/Microsoft.ML.Transforms/OneHotHashEncodingTransformer.cs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ public static IDataView Create(IHostEnvironment env,
145145
int invertHash = OneHotHashEncodingEstimator.Defaults.InvertHash,
146146
OneHotEncodingTransformer.OutputKind outputKind = OneHotHashEncodingEstimator.Defaults.OutputKind)
147147
{
148-
return new OneHotHashEncodingEstimator(env, name, source, outputKind).Fit(input).Transform(input) as IDataView;
148+
return new OneHotHashEncodingEstimator(env, name, source, hashBits, invertHash, outputKind).Fit(input).Transform(input) as IDataView;
149149
}
150150

151151
internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
@@ -240,10 +240,16 @@ public ColumnInfo(string input, string output,
240240
/// <param name="env">Host Environment.</param>
241241
/// <param name="inputColumn">Name of the input column.</param>
242242
/// <param name="outputColumn">Name of the output column. If this is null '<paramref name="inputColumn"/>' will be used.</param>
243+
/// <param name="hashBits">Number of bits to hash into. Must be between 1 and 30, inclusive.</param>
244+
/// <param name="invertHash">Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.</param>
243245
/// <param name="outputKind">The type of output expected.</param>
244-
public OneHotHashEncodingEstimator(IHostEnvironment env, string inputColumn,
245-
string outputColumn = null, OneHotEncodingTransformer.OutputKind outputKind = Defaults.OutputKind)
246-
: this(env, new ColumnInfo(inputColumn, outputColumn ?? inputColumn, outputKind))
246+
public OneHotHashEncodingEstimator(IHostEnvironment env,
247+
string inputColumn,
248+
string outputColumn,
249+
int hashBits = OneHotHashEncodingEstimator.Defaults.HashBits,
250+
int invertHash = OneHotHashEncodingEstimator.Defaults.InvertHash,
251+
OneHotEncodingTransformer.OutputKind outputKind = Defaults.OutputKind)
252+
: this(env, new ColumnInfo(inputColumn, outputColumn ?? inputColumn, outputKind, hashBits, invertHash: invertHash))
247253
{
248254
}
249255

test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,20 @@ public void CategoricalWorkout()
6464
Done();
6565
}
6666

67+
[Fact]
68+
public void CategoricalOneHotHashEncoding()
69+
{
70+
var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
71+
72+
var mlContext = new MLContext();
73+
var dataView = ComponentCreation.CreateDataView(mlContext, data);
74+
75+
var pipe = mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatA", 16, 0, OneHotEncodingTransformer.OutputKind.Bag);
76+
77+
TestEstimatorCore(pipe, dataView);
78+
Done();
79+
}
80+
6781
[Fact]
6882
public void CategoricalStatic()
6983
{

0 commit comments

Comments
 (0)