Skip to content

Hash estimator #944

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Calibration;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Transforms;

[assembly: LoadableClass(typeof(CrossValidationCommand), typeof(CrossValidationCommand.Arguments), typeof(SignatureCommand),
"Cross Validation", CrossValidationCommand.LoadName)]
Expand Down Expand Up @@ -329,10 +330,7 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output
int inc = 0;
while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc);
var hashargs = new HashTransform.Arguments();
hashargs.Column = new[] { new HashTransform.Column { Source = origStratCol, Name = stratificationColumn } };
hashargs.HashBits = 30;
output = new HashTransform(Host, hashargs, input);
output = new HashEstimator(Host, origStratCol, stratificationColumn, 30).Fit(input).Transform(input);
}
}

Expand Down
719 changes: 415 additions & 304 deletions src/Microsoft.ML.Data/Transforms/HashTransform.cs

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions src/Microsoft.ML.Data/Transforms/TermEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ public static class Defaults
/// Convenience constructor for public facing API.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
/// <param name="inputColumn">Name of the column to be transformed.</param>
/// <param name="outputColumn">Name of the output column. If this is null '<paramref name="inputColumn"/>' will be used.</param>
/// <param name="maxNumTerms">Maximum number of terms to keep per column when auto-training.</param>
/// <param name="sort">How items should be ordered when vectorized. By default, they will be in the order encountered.
/// If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a').</param>
public TermEstimator(IHostEnvironment env, string name, string source = null, int maxNumTerms = Defaults.MaxNumTerms, TermTransform.SortOrder sort = Defaults.Sort) :
this(env, new TermTransform.ColumnInfo(name, source ?? name, maxNumTerms, sort))
public TermEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, int maxNumTerms = Defaults.MaxNumTerms, TermTransform.SortOrder sort = Defaults.Sort) :
this(env, new TermTransform.ColumnInfo(inputColumn, outputColumn ?? inputColumn, maxNumTerms, sort))
{
}

Expand Down
11 changes: 6 additions & 5 deletions src/Microsoft.ML.Transforms/CategoricalHashTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Transforms;

[assembly: LoadableClass(CategoricalHashTransform.Summary, typeof(IDataTransform), typeof(CategoricalHashTransform), typeof(CategoricalHashTransform.Arguments), typeof(SignatureDataTransform),
CategoricalHashTransform.UserName, "CategoricalHashTransform", "CatHashTransform", "CategoricalHash", "CatHash")]
Expand Down Expand Up @@ -91,7 +92,7 @@ private static class Defaults
}

/// <summary>
/// This class is a merger of <see cref="HashTransform.Arguments"/> and <see cref="KeyToVectorTransform.Arguments"/>
/// This class is a merger of <see cref="HashTransformer.Arguments"/> and <see cref="KeyToVectorTransform.Arguments"/>
/// with join option removed
/// </summary>
public sealed class Arguments : TransformInputBase
Expand Down Expand Up @@ -169,13 +170,13 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
throw h.ExceptUserArg(nameof(args.HashBits), "Number of bits must be between 1 and {0}", NumBitsLim - 1);

// creating the Hash function
var hashArgs = new HashTransform.Arguments
var hashArgs = new HashTransformer.Arguments
{
HashBits = args.HashBits,
Seed = args.Seed,
Ordered = args.Ordered,
InvertHash = args.InvertHash,
Column = new HashTransform.Column[args.Column.Length]
Column = new HashTransformer.Column[args.Column.Length]
};
for (int i = 0; i < args.Column.Length; i++)
{
Expand All @@ -184,7 +185,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
throw h.ExceptUserArg(nameof(Column.Name));
h.Assert(!string.IsNullOrWhiteSpace(column.Name));
h.Assert(!string.IsNullOrWhiteSpace(column.Source));
hashArgs.Column[i] = new HashTransform.Column
hashArgs.Column[i] = new HashTransformer.Column
{
HashBits = column.HashBits,
Seed = column.Seed,
Expand All @@ -198,7 +199,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
return CreateTransformCore(
args.OutputKind, args.Column,
args.Column.Select(col => col.OutputKind).ToList(),
new HashTransform(h, hashArgs, input),
HashTransformer.Create(h, hashArgs, input),
h,
args);
}
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Transforms;

[assembly: LoadableClass(WordBagTransform.Summary, typeof(IDataTransform), typeof(WordBagTransform), typeof(WordBagTransform.Arguments), typeof(SignatureDataTransform),
"Word Bag Transform", "WordBagTransform", "WordBag")]
Expand Down Expand Up @@ -474,7 +475,7 @@ public interface INgramExtractorFactory
{
/// <summary>
/// Whether the extractor transform created by this factory uses the hashing trick
/// (by using <see cref="HashTransform"/> or <see cref="NgramHashTransform"/>, for example).
/// (by using <see cref="HashTransformer"/> or <see cref="NgramHashTransform"/>, for example).
/// </summary>
bool UseHashingTrick { get; }

Expand Down
11 changes: 6 additions & 5 deletions src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Transforms;

[assembly: LoadableClass(WordHashBagTransform.Summary, typeof(IDataTransform), typeof(WordHashBagTransform), typeof(WordHashBagTransform.Arguments), typeof(SignatureDataTransform),
"Word Hash Bag Transform", "WordHashBagTransform", "WordHashBag")]
Expand Down Expand Up @@ -266,7 +267,7 @@ public bool TryUnparse(StringBuilder sb)
}

/// <summary>
/// This class is a merger of <see cref="HashTransform.Arguments"/> and
/// This class is a merger of <see cref="HashTransformer.Arguments"/> and
/// <see cref="NgramHashTransform.Arguments"/>, with the ordered option,
/// the rehashUnigrams option and the allLength option removed.
/// </summary>
Expand Down Expand Up @@ -340,7 +341,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
List<TermTransform.Column> termCols = null;
if (termLoaderArgs != null)
termCols = new List<TermTransform.Column>();
var hashColumns = new List<HashTransform.Column>();
var hashColumns = new List<HashTransformer.Column>();
var ngramHashColumns = new NgramHashTransform.Column[args.Column.Length];

var colCount = args.Column.Length;
Expand Down Expand Up @@ -371,7 +372,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
}

hashColumns.Add(
new HashTransform.Column
new HashTransformer.Column
{
Name = tmpName,
Source = termLoaderArgs == null ? column.Source[isrc] : tmpName,
Expand Down Expand Up @@ -435,7 +436,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV

// Args for the Hash function with multiple columns
var hashArgs =
new HashTransform.Arguments
new HashTransformer.Arguments
{
HashBits = 31,
Seed = args.Seed,
Expand All @@ -444,7 +445,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
InvertHash = args.InvertHash
};

view = new HashTransform(h, hashArgs, view);
view = HashTransformer.Create(h, hashArgs, view);

// creating the NgramHash function
var ngramHashArgs =
Expand Down
32 changes: 13 additions & 19 deletions test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,12 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Float = System.Single;

using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.IO;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.TextAnalytics;
using Microsoft.ML.Transforms;
using System;
using Xunit;
using Float = System.Single;

namespace Microsoft.ML.Runtime.RunTests
{
Expand Down Expand Up @@ -82,14 +76,14 @@ private void TestHashTransformHelper<T>(T[] data, uint[] results, NumberType typ
builder.AddColumn("F1", type, data);
var srcView = builder.GetDataView();

HashTransform.Column col = new HashTransform.Column();
col.Source = "F1";
var col = new HashTransformer.Column();
col.Name = "F1";
col.HashBits = 5;
col.Seed = 42;
HashTransform.Arguments args = new HashTransform.Arguments();
args.Column = new HashTransform.Column[] { col };
var args = new HashTransformer.Arguments();
args.Column = new HashTransformer.Column[] { col };

var hashTransform = new HashTransform(Env, args, srcView);
var hashTransform = HashTransformer.Create(Env, args, srcView);
using (var cursor = hashTransform.GetRowCursor(c => true))
{
var resultGetter = cursor.GetGetter<uint>(1);
Expand Down Expand Up @@ -120,14 +114,14 @@ private void TestHashTransformVectorHelper<T>(VBuffer<T> data, uint[][] results,
private void TestHashTransformVectorHelper(ArrayDataViewBuilder builder, uint[][] results)
{
var srcView = builder.GetDataView();
HashTransform.Column col = new HashTransform.Column();
col.Source = "F1V";
var col = new HashTransformer.Column();
col.Name = "F1V";
col.HashBits = 5;
col.Seed = 42;
HashTransform.Arguments args = new HashTransform.Arguments();
args.Column = new HashTransform.Column[] { col };
var args = new HashTransformer.Arguments();
args.Column = new HashTransformer.Column[] { col };

var hashTransform = new HashTransform(Env, args, srcView);
var hashTransform = HashTransformer.Create(Env, args, srcView);
using (var cursor = hashTransform.GetRowCursor(c => true))
{
var resultGetter = cursor.GetGetter<VBuffer<uint>>(1);
Expand Down
134 changes: 134 additions & 0 deletions test/Microsoft.ML.Tests/Transformers/HashTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.RunTests;
using Microsoft.ML.Runtime.Tools;
using Microsoft.ML.Transforms;
using System;
using System.IO;
using System.Linq;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.ML.Tests.Transformers
{
/// <summary>
/// Tests for <see cref="HashEstimator"/> configured through <see cref="HashTransformer.ColumnInfo"/>:
/// estimator workout, key-value metadata produced by invertHash, command-line loading,
/// and saving/loading through the model file utilities.
/// </summary>
public class HashTests : TestDataPipeBase
{
    public HashTests(ITestOutputHelper output) : base(output)
    {
    }

    // Row type for the scalar tests. The tests below address these fields as columns "A", "B" and "C",
    // so the field names must not be changed.
    private class TestClass
    {
        public float A;
        public float B;
        public float C;
    }

    // Row type for the metadata test. Only column "A" is hashed by the tests in this class.
    private class TestMeta
    {
        [VectorType(2)]
        public float[] A;
        public float B;
        [VectorType(2)]
        public double[] C;
        public double D;
    }

    /// <summary>
    /// Runs the standard estimator checks over a pipeline that hashes columns with a mix of
    /// hashBits, invertHash, ordered and seed settings.
    /// </summary>
    [Fact]
    public void HashWorkout()
    {
        var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };

        var dataView = ComponentCreation.CreateDataView(Env, data);
        var pipe = new HashEstimator(Env, new[]{
            new HashTransformer.ColumnInfo("A", "HashA", hashBits:4, invertHash:-1),
            new HashTransformer.ColumnInfo("B", "HashB", hashBits:3, ordered:true),
            new HashTransformer.ColumnInfo("C", "HashC", seed:42),
            new HashTransformer.ColumnInfo("A", "HashD"),
        });

        TestEstimatorCore(pipe, dataView);
        Done();
    }

    /// <summary>
    /// Hashes the vector column "A" three ways (invertHash limited to 1, unlimited, and unlimited + ordered)
    /// and validates the KeyValues metadata attached to each output column.
    /// </summary>
    [Fact]
    public void TestMetadata()
    {
        var data = new[] {
            new TestMeta() { A=new float[2] { 3.5f, 2.5f}, B=1, C= new double[2] { 5.1f, 6.1f}, D= 7},
            new TestMeta() { A=new float[2] { 3.5f, 2.5f}, B=1, C= new double[2] { 5.1f, 6.1f}, D= 7},
            new TestMeta() { A=new float[2] { 3.5f, 2.5f}, B=1, C= new double[2] { 5.1f, 6.1f}, D= 7}};

        var dataView = ComponentCreation.CreateDataView(Env, data);
        var pipe = new HashEstimator(Env, new[] {
            new HashTransformer.ColumnInfo("A", "HashA", invertHash:1, hashBits:10),
            new HashTransformer.ColumnInfo("A", "HashAUnlim", invertHash:-1, hashBits:10),
            new HashTransformer.ColumnInfo("A", "HashAUnlimOrdered", invertHash:-1, hashBits:10, ordered:true)
        });
        var result = pipe.Fit(dataView).Transform(dataView);
        ValidateMetadata(result);
        Done();
    }

    /// <summary>
    /// Checks that each hashed output column exposes exactly the KeyValues metadata kind and that the
    /// key values of every column (not just "HashA") contain the expected inverted-hash strings.
    /// </summary>
    /// <param name="result">The data view produced by the pipeline built in <see cref="TestMetadata"/>.</param>
    private void ValidateMetadata(IDataView result)
    {
        Assert.True(result.Schema.TryGetColumnIndex("HashA", out int hashA));
        Assert.True(result.Schema.TryGetColumnIndex("HashAUnlim", out int hashAUnlim));
        Assert.True(result.Schema.TryGetColumnIndex("HashAUnlimOrdered", out int hashAUnlimOrdered));

        //REVIEW: This is weird. I specified invertHash to 1 so I expect only one value to be in key values, but i got two.
        Assert.Equal(new[] { "2.5", "3.5" }, ReadKeyValues(result, hashA));

        // Same hashBits/seed/ordering as HashA, only the invertHash limit differs, so the same values are expected.
        Assert.Equal(new[] { "2.5", "3.5" }, ReadKeyValues(result, hashAUnlim));

        // The original test re-read HashA here, so the ordered column was never actually inspected.
        // NOTE(review): the expected strings assume ordered hashing renders each key as "<position>:<value>";
        // compared after sorting because slot order depends on the hash values. Confirm by running the test.
        Assert.Equal(
            new[] { "0:3.5", "1:2.5" },
            ReadKeyValues(result, hashAUnlimOrdered).OrderBy(s => s, StringComparer.Ordinal));
    }

    /// <summary>
    /// Asserts that column <paramref name="column"/> carries only KeyValues metadata with 1024 slots
    /// (the tests use hashBits:10) and returns the text of the items enumerated by <c>Items()</c>.
    /// </summary>
    private static string[] ReadKeyValues(IDataView result, int column)
    {
        var kinds = result.Schema.GetMetadataTypes(column).Select(x => x.Key);
        Assert.Equal(new[] { MetadataUtils.Kinds.KeyValues }, kinds);

        VBuffer<ReadOnlyMemory<char>> keys = default;
        result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, column, ref keys);
        Assert.Equal(1024, keys.Length);

        return keys.Items().Select(x => x.Value.ToString()).ToArray();
    }

    /// <summary>
    /// Verifies that the Hash transform can be instantiated from a command line and that the
    /// command exits with code 0.
    /// </summary>
    [Fact]
    public void TestCommandLine()
    {
        // NOTE(review): 'in=f:\2.txt' is an absolute, machine-specific path. Left unchanged here;
        // consider pointing it at a test data file if the command actually opens the input.
        Assert.Equal(0, Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Hash{col=B:A} in=f:\2.txt" }));
    }

    /// <summary>
    /// Saves the transformed pipeline with <c>TrainUtils.SaveModel</c> and loads the transforms back
    /// with <c>ModelFileUtils.LoadTransforms</c>.
    /// </summary>
    [Fact]
    public void TestOldSavingAndLoading()
    {
        var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
        var dataView = ComponentCreation.CreateDataView(Env, data);
        var pipe = new HashEstimator(Env, new[]{
            new HashTransformer.ColumnInfo("A", "HashA", hashBits:4, invertHash:-1),
            new HashTransformer.ColumnInfo("B", "HashB", hashBits:3, ordered:true),
            new HashTransformer.ColumnInfo("C", "HashC", seed:42),
            new HashTransformer.ColumnInfo("A", "HashD"),
        });
        var result = pipe.Fit(dataView).Transform(dataView);
        var resultRoles = new RoleMappedData(result);
        using (var ms = new MemoryStream())
        {
            // Dispose the channel once saving is finished instead of leaking it.
            using (var ch = Env.Start("saving"))
            {
                TrainUtils.SaveModel(Env, ch, ms, null, resultRoles);
            }

            ms.Position = 0;
            var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms);
            Assert.NotNull(loadedView);
        }
    }
}
}