Skip to content

Pass hashBits, invertHash to OneHotHashEncodingEstimator #1564

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/Microsoft.ML.Transforms/CategoricalCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,17 @@ public static OneHotEncodingEstimator OneHotEncoding(this TransformsCatalog.Cate
/// <param name="catalog">The transform catalog</param>
/// <param name="inputColumn">The input column</param>
/// <param name="outputColumn">The output column. If <c>null</c>, <paramref name="inputColumn"/> is used.</param>
/// <param name="hashBits">Number of bits to hash into. Must be between 1 and 30, inclusive.</param>
/// <param name="invertHash">Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.</param>
/// <param name="outputKind">The conversion mode.</param>
/// <returns></returns>
public static OneHotHashEncodingEstimator OneHotHashEncoding(this TransformsCatalog.CategoricalTransforms catalog,
string inputColumn,
string outputColumn = null,
int hashBits = OneHotHashEncodingEstimator.Defaults.HashBits,
int invertHash = OneHotHashEncodingEstimator.Defaults.InvertHash,
OneHotEncodingTransformer.OutputKind outputKind = OneHotEncodingTransformer.OutputKind.Ind)
=> new OneHotHashEncodingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, outputKind);
=> new OneHotHashEncodingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, hashBits, invertHash, outputKind);

/// <summary>
/// Convert several text column into hash-based one-hot encoded vectors.
Expand Down
14 changes: 10 additions & 4 deletions src/Microsoft.ML.Transforms/OneHotHashEncodingTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ public static IDataView Create(IHostEnvironment env,
int invertHash = OneHotHashEncodingEstimator.Defaults.InvertHash,
OneHotEncodingTransformer.OutputKind outputKind = OneHotHashEncodingEstimator.Defaults.OutputKind)
{
return new OneHotHashEncodingEstimator(env, name, source, outputKind).Fit(input).Transform(input) as IDataView;
return new OneHotHashEncodingEstimator(env, name, source, hashBits, invertHash, outputKind).Fit(input).Transform(input) as IDataView;
}

internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
Expand Down Expand Up @@ -240,10 +240,16 @@ public ColumnInfo(string input, string output,
/// <param name="env">Host Environment.</param>
/// <param name="inputColumn">Name of the input column.</param>
/// <param name="outputColumn">Name of the output column. If this is null '<paramref name="inputColumn"/>' will be used.</param>
/// <param name="hashBits">Number of bits to hash into. Must be between 1 and 30, inclusive.</param>
/// <param name="invertHash">Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.</param>
/// <param name="outputKind">The type of output expected.</param>
public OneHotHashEncodingEstimator(IHostEnvironment env, string inputColumn,
string outputColumn = null, OneHotEncodingTransformer.OutputKind outputKind = Defaults.OutputKind)
: this(env, new ColumnInfo(inputColumn, outputColumn ?? inputColumn, outputKind))
public OneHotHashEncodingEstimator(IHostEnvironment env,
string inputColumn,
string outputColumn,
int hashBits = OneHotHashEncodingEstimator.Defaults.HashBits,
int invertHash = OneHotHashEncodingEstimator.Defaults.InvertHash,
OneHotEncodingTransformer.OutputKind outputKind = Defaults.OutputKind)
: this(env, new ColumnInfo(inputColumn, outputColumn ?? inputColumn, outputKind, hashBits, invertHash: invertHash))
{
}

Expand Down
14 changes: 14 additions & 0 deletions test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,20 @@ public void CategoricalWorkout()
Done();
}

[Fact]
public void CategoricalOneHotHashEncoding()
{
var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };

var mlContext = new MLContext();
var dataView = ComponentCreation.CreateDataView(mlContext, data);

var pipe = mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatA", 16, 0, OneHotEncodingTransformer.OutputKind.Bag);

TestEstimatorCore(pipe, dataView);
Done();
}

[Fact]
public void CategoricalStatic()
{
Expand Down