From bd01a0cda4bacd6c8588de0e2106fb898c8b8079 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Tue, 18 Sep 2018 17:24:37 -0700 Subject: [PATCH 1/8] no pigsty so far --- .../Commands/CrossValidationCommand.cs | 2 +- .../Transforms/HashTransform.cs | 687 ++++++++++-------- .../CategoricalHashTransform.cs | 2 +- .../Text/WordHashBagTransform.cs | 2 +- .../DataPipe/TestDataPipe.cs | 4 +- .../Transformers/HashTests.cs | 59 ++ 6 files changed, 451 insertions(+), 305 deletions(-) create mode 100644 test/Microsoft.ML.Tests/Transformers/HashTests.cs diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index b0f2e12ef1..f1c720ceef 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -332,7 +332,7 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output var hashargs = new HashTransform.Arguments(); hashargs.Column = new[] { new HashTransform.Column { Source = origStratCol, Name = stratificationColumn } }; hashargs.HashBits = 30; - output = new HashTransform(Host, hashargs, input); + output = HashTransform.Create(Host, hashargs, input); } } diff --git a/src/Microsoft.ML.Data/Transforms/HashTransform.cs b/src/Microsoft.ML.Data/Transforms/HashTransform.cs index e0aa599e49..61413b7473 100644 --- a/src/Microsoft.ML.Data/Transforms/HashTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/HashTransform.cs @@ -7,18 +7,25 @@ using System.Linq; using System.Runtime.CompilerServices; using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; -[assembly: LoadableClass(HashTransform.Summary, typeof(HashTransform), typeof(HashTransform.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(HashTransform.Summary, 
typeof(IDataTransform), typeof(HashTransform), typeof(HashTransform.Arguments), typeof(SignatureDataTransform), "Hash Transform", "HashTransform", "Hash", DocName = "transform/HashTransform.md")] -[assembly: LoadableClass(HashTransform.Summary, typeof(HashTransform), null, typeof(SignatureLoadDataTransform), +[assembly: LoadableClass(HashTransform.Summary, typeof(IDataTransform), typeof(HashTransform), null, typeof(SignatureLoadDataTransform), "Hash Transform", HashTransform.LoaderSignature)] +[assembly: LoadableClass(HashTransform.Summary, typeof(HashTransform), null, typeof(SignatureLoadModel), + "Hash Transform", HashTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(HashTransform), null, typeof(SignatureLoadRowMapper), + "Hash Transform", HashTransform.LoaderSignature)] + namespace Microsoft.ML.Runtime.Data { using Conditional = System.Diagnostics.ConditionalAttribute; @@ -28,19 +35,8 @@ namespace Microsoft.ML.Runtime.Data /// it hashes each slot separately. /// It can hash either text values or key values. /// - public sealed class HashTransform : OneToOneTransformBase, ITransformTemplate + public sealed class HashTransform : OneToOneTransformerBase { - public const int NumBitsMin = 1; - public const int NumBitsLim = 32; - - private static class Defaults - { - public const int HashBits = NumBitsLim - 1; - public const uint Seed = 314489979; - public const bool Ordered = false; - public const int InvertHash = 0; - } - public sealed class Arguments { [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", @@ -49,18 +45,18 @@ public sealed class Arguments [Argument(ArgumentType.AtMostOnce, HelpText = "Number of bits to hash into. 
Must be between 1 and 31, inclusive", ShortName = "bits", SortOrder = 2)] - public int HashBits = Defaults.HashBits; + public int HashBits = HashEstimator.Defaults.HashBits; [Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")] - public uint Seed = Defaults.Seed; + public uint Seed = HashEstimator.Defaults.Seed; [Argument(ArgumentType.AtMostOnce, HelpText = "Whether the position of each term should be included in the hash", ShortName = "ord")] - public bool Ordered = Defaults.Ordered; + public bool Ordered = HashEstimator.Defaults.Ordered; [Argument(ArgumentType.AtMostOnce, HelpText = "Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.", ShortName = "ih")] - public int InvertHash = Defaults.InvertHash; + public int InvertHash = HashEstimator.Defaults.InvertHash; } public sealed class Column : OneToOneColumn @@ -95,8 +91,7 @@ protected override bool TryParse(string str) // We accept N:B:S where N is the new column name, B is the number of bits, // and S is source column names. - string extra; - if (!base.TryParse(str, out extra)) + if (!base.TryParse(str, out string extra)) return false; if (extra == null) return true; @@ -119,52 +114,68 @@ public bool TryUnparse(StringBuilder sb) return TryUnparseCore(sb, extra); } } - - private sealed class ColInfoEx + public sealed class ColumnInfo { + public readonly string Input; + public readonly string Output; public readonly int HashBits; - public readonly uint HashSeed; + public readonly uint Seed; public readonly bool Ordered; + public readonly int InvertHash; - public ColInfoEx(Arguments args, Column col) + /// + /// Describes how the transformer handles one column pair. + /// + /// Name of input column. + /// Name of output column. + /// Number of bits to hash into. Must be between 1 and 31, inclusive. + /// Hashing seed. + /// Whether the position of each term should be included in the hash. 
+ /// Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit. + public ColumnInfo(string input, string output, + int hashBits = HashEstimator.Defaults.HashBits, + uint seed = HashEstimator.Defaults.Seed, + bool ordered = HashEstimator.Defaults.Ordered, + int invertHash = HashEstimator.Defaults.InvertHash) { - HashBits = col.HashBits ?? args.HashBits; - if (HashBits < NumBitsMin || HashBits >= NumBitsLim) - throw Contracts.ExceptUserArg(nameof(args.HashBits), "Should be between {0} and {1} inclusive", NumBitsMin, NumBitsLim - 1); - HashSeed = col.Seed ?? args.Seed; - Ordered = col.Ordered ?? args.Ordered; + Input = input; + Output = output; + HashBits = hashBits; + Seed = seed; + Ordered = ordered; + InvertHash = invertHash; } - public ColInfoEx(ModelLoadContext ctx) + internal ColumnInfo(string input, string output, ModelLoadContext ctx) { + Input = input; + Output = output; // *** Binary format *** // int: HashBits // uint: HashSeed // byte: Ordered - HashBits = ctx.Reader.ReadInt32(); - Contracts.CheckDecode(NumBitsMin <= HashBits && HashBits < NumBitsLim); - - HashSeed = ctx.Reader.ReadUInt32(); + Contracts.CheckDecode(HashEstimator.NumBitsMin <= HashBits && HashBits < HashEstimator.NumBitsLim); + Seed = ctx.Reader.ReadUInt32(); Ordered = ctx.Reader.ReadBoolByte(); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { // *** Binary format *** // int: HashBits // uint: HashSeed // byte: Ordered - Contracts.Assert(NumBitsMin <= HashBits && HashBits < NumBitsLim); + Contracts.Assert(HashEstimator.NumBitsMin <= HashBits && HashBits < HashEstimator.NumBitsLim); ctx.Writer.Write(HashBits); - ctx.Writer.Write(HashSeed); + ctx.Writer.Write(Seed); ctx.Writer.WriteBoolByte(Ordered); } } - private static string TestType(ColumnType type) + public static string TestType(ColumnType type) { if (type.ItemType.IsText || type.ItemType.IsKey || type.ItemType == NumberType.R4 || 
type.ItemType == NumberType.R8) return null; @@ -188,109 +199,66 @@ private static VersionInfo GetVersionInfo() loaderSignature: LoaderSignature); } - private readonly ColInfoEx[] _exes; - private readonly ColumnType[] _types; - + private readonly ColumnInfo[] _columns; private readonly VBuffer[] _keyValues; private readonly ColumnType[] _kvTypes; - public static HashTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { - Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); - ctx.CheckAtModel(GetVersionInfo()); - return h.Apply("Loading Model", ch => new HashTransform(h, ctx, input)); + var type = inputSchema.GetColumnType(srcCol); + string reason = TestType(type); + if (reason != null) + throw Host.ExceptParam(nameof(inputSchema), reason); } - - private HashTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, TestType) + private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) { - Host.AssertValue(ctx); - - // *** Binary format *** - // - // Exes - - Host.AssertNonEmpty(Infos); - _exes = new ColInfoEx[Infos.Length]; - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) - _exes[iinfo] = new ColInfoEx(ctx); - - _types = InitColumnTypes(); - - TextModelHelper.LoadAll(Host, ctx, Infos.Length, out _keyValues, out _kvTypes); - SetMetadata(); + Contracts.CheckValue(columns, nameof(columns)); + return columns.Select(x => (x.Input, x.Output)).ToArray(); } - public override void Save(ModelSaveContext ctx) + internal ColumnType GetOutputType(ISchema inputSchema, ColumnInfo column) { - Host.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(); - ctx.SetVersionInfo(GetVersionInfo()); - - // - // - // Exes - - SaveBase(ctx); - Host.Assert(_exes.Length == Infos.Length); - for (int iinfo = 0; iinfo < Infos.Length; 
iinfo++) - _exes[iinfo].Save(ctx); - - TextModelHelper.SaveAll(Host, ctx, Infos.Length, _keyValues); + var keyCount = column.HashBits < 31 ? 1 << column.HashBits : 0; + inputSchema.TryGetColumnIndex(column.Input, out int srcCol); + var itemType = new KeyType(DataKind.U4, 0, keyCount, keyCount > 0); + var srcType = inputSchema.GetColumnType(srcCol); + if (!srcType.IsVector) + return itemType; + else + return new VectorType(itemType, srcType.VectorSize); } - /// - /// Convenience constructor for public facing API. - /// - /// Host Environment. - /// Input . This is the output from previous transform or loader. - /// Name of the output column. - /// Name of the column to be transformed. If this is null '' will be used. - /// Number of bits to hash into. Must be between 1 and 31, inclusive. - /// Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit. - public HashTransform(IHostEnvironment env, - IDataView input, - string name, - string source = null, - int hashBits = Defaults.HashBits, - int invertHash = Defaults.InvertHash) - : this(env, new Arguments() { - Column = new[] { new Column() { Source = source ?? 
name, Name = name } }, - HashBits = hashBits, InvertHash = invertHash }, input) + public HashTransform(IHostEnvironment env, IDataView input, ColumnInfo[] columns) : + base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) { - } - - public HashTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(Contracts.CheckRef(env, nameof(env)), RegistrationName, env.CheckRef(args, nameof(args)).Column, - input, TestType) - { - if (args.HashBits < NumBitsMin || args.HashBits >= NumBitsLim) - throw Host.ExceptUserArg(nameof(args.HashBits), "hashBits should be between {0} and {1} inclusive", NumBitsMin, NumBitsLim - 1); - - _exes = new ColInfoEx[Infos.Length]; + _columns = columns.ToArray(); + //IVAN: Validate input schema + var types = new ColumnType[_columns.Length]; List invertIinfos = null; List invertHashMaxCounts = null; - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) + HashSet sourceColumnsForInvertHash = new HashSet(); + for (int i = 0; i < _columns.Length; i++) { - _exes[iinfo] = new ColInfoEx(args, args.Column[iinfo]); - int invertHashMaxCount = GetAndVerifyInvertHashMaxCount(args, args.Column[iinfo], _exes[iinfo]); + if (!input.Schema.TryGetColumnIndex(ColumnPairs[i].input, out int srcCol)) + throw Host.ExceptSchemaMismatch(nameof(input), "input", ColumnPairs[i].input); + CheckInputColumn(input.Schema, i, srcCol); + + types[i] = GetOutputType(input.Schema, _columns[i]); + int invertHashMaxCount; + if (_columns[i].InvertHash == -1) + invertHashMaxCount = int.MaxValue; + else + invertHashMaxCount = _columns[i].InvertHash; if (invertHashMaxCount > 0) { - Utils.Add(ref invertIinfos, iinfo); + Utils.Add(ref invertIinfos, i); Utils.Add(ref invertHashMaxCounts, invertHashMaxCount); + sourceColumnsForInvertHash.Add(srcCol); } } - - _types = InitColumnTypes(); - - if (Utils.Size(invertIinfos) > 0) + if (Utils.Size(sourceColumnsForInvertHash) > 0) { - // Build the invert hashes for all columns for which it was 
requested. - var srcs = new HashSet(invertIinfos.Select(i => Infos[i].Source)); - using (IRowCursor srcCursor = input.GetRowCursor(srcs.Contains)) + using (IRowCursor srcCursor = input.GetRowCursor(sourceColumnsForInvertHash.Contains)) { using (var ch = Host.Start("Invert hash building")) { @@ -299,160 +267,152 @@ public HashTransform(IHostEnvironment env, Arguments args, IDataView input) for (int i = 0; i < helpers.Length; ++i) { int iinfo = invertIinfos[i]; - Host.Assert(_types[iinfo].ItemType.KeyCount > 0); - var dstGetter = GetGetterCore(ch, srcCursor, iinfo, out disposer); + Host.Assert(types[iinfo].ItemType.KeyCount > 0); + var dstGetter = GetGetterCore(srcCursor, iinfo, out disposer); Host.Assert(disposer == null); - var ex = _exes[iinfo]; + var ex = _columns[iinfo]; var maxCount = invertHashMaxCounts[i]; - helpers[i] = InvertHashHelper.Create(srcCursor, Infos[iinfo], ex, maxCount, dstGetter); + helpers[i] = InvertHashHelper.Create(srcCursor, ex, maxCount, dstGetter); } while (srcCursor.MoveNext()) { for (int i = 0; i < helpers.Length; ++i) helpers[i].Process(); } - _keyValues = new VBuffer[_exes.Length]; - _kvTypes = new ColumnType[_exes.Length]; + _keyValues = new VBuffer[_columns.Length]; + _kvTypes = new ColumnType[_columns.Length]; for (int i = 0; i < helpers.Length; ++i) { _keyValues[invertIinfos[i]] = helpers[i].GetKeyValuesMetadata(); - Host.Assert(_keyValues[invertIinfos[i]].Length == _types[invertIinfos[i]].ItemType.KeyCount); + Host.Assert(_keyValues[invertIinfos[i]].Length == types[invertIinfos[i]].ItemType.KeyCount); _kvTypes[invertIinfos[i]] = new VectorType(TextType.Instance, _keyValues[invertIinfos[i]].Length); } ch.Done(); } } } - SetMetadata(); } - /// - /// Re-apply constructor. 
- /// - private HashTransform(IHostEnvironment env, HashTransform transform, IDataView newSource) - : base(env, RegistrationName, transform, newSource, TestType) + internal Delegate GetGetterCore(IRow input, int iinfo, out Action disposer) { - _exes = transform._exes; - _types = InitColumnTypes(); - _keyValues = transform._keyValues; - _kvTypes = transform._kvTypes; - SetMetadata(); + Host.AssertValue(input); + Host.Assert(0 <= iinfo && iinfo < _columns.Length); + disposer = null; + input.Schema.TryGetColumnIndex(_columns[iinfo].Input, out int srcCol); + var srcType = input.Schema.GetColumnType(srcCol); + if (!srcType.IsVector) + return ComposeGetterOne(input, iinfo, srcCol, srcType); + return ComposeGetterVec(input, iinfo, srcCol, srcType); } - public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource) - { - return new HashTransform(env, this, newSource); - } + protected override IRowMapper MakeRowMapper(ISchema schema) => new Mapper(this, schema); - private static int GetAndVerifyInvertHashMaxCount(Arguments args, Column col, ColInfoEx ex) + // Factory method for SignatureLoadModel. + private static HashTransform Create(IHostEnvironment env, ModelLoadContext ctx) { - var invertHashMaxCount = col.InvertHash ?? args.InvertHash; - if (invertHashMaxCount != 0) - { - if (invertHashMaxCount == -1) - invertHashMaxCount = int.MaxValue; - Contracts.CheckUserArg(invertHashMaxCount > 0, nameof(args.InvertHash), "Value too small, must be -1 or larger"); - // If the bits is 31 or higher, we can't declare a KeyValues of the appropriate length, - // this requiring a VBuffer of length 1u << 31 which exceeds int.MaxValue. - if (ex.HashBits >= 31) - throw Contracts.ExceptUserArg(nameof(args.InvertHash), "Cannot support invertHash for a {0} bit hash. 
30 is the maximum possible.", ex.HashBits); - } - return invertHashMaxCount; - } + Contracts.CheckValue(env, nameof(env)); + var host = env.Register(RegistrationName); - private ColumnType[] InitColumnTypes() - { - var types = new ColumnType[Infos.Length]; - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) - { - var keyCount = _exes[iinfo].HashBits < 31 ? 1 << _exes[iinfo].HashBits : 0; - var itemType = new KeyType(DataKind.U4, 0, keyCount, keyCount > 0); - if (!Infos[iinfo].TypeSrc.IsVector) - types[iinfo] = itemType; - else - types[iinfo] = new VectorType(itemType, Infos[iinfo].TypeSrc.VectorSize); - } - return types; - } + host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); - protected override ColumnType GetColumnTypeCore(int iinfo) - { - Host.Check(0 <= iinfo & iinfo < Infos.Length); - return _types[iinfo]; + return new HashTransform(host, ctx); } - private void SetMetadata() + private HashTransform(IHost host, ModelLoadContext ctx) + : base(host, ctx) { - var md = Metadata; - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) - { - using (var bldr = md.BuildMetadata(iinfo, Source.Schema, Infos[iinfo].Source, - MetadataUtils.Kinds.SlotNames)) - { - if (_kvTypes != null && _kvTypes[iinfo] != null) - bldr.AddGetter>(MetadataUtils.Kinds.KeyValues, _kvTypes[iinfo], GetTerms); - } - } - md.Seal(); + var columnsLength = ColumnPairs.Length; + _columns = new ColumnInfo[columnsLength]; + for (int i = 0; i < columnsLength; i++) + _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output, ctx); + TextModelHelper.LoadAll(Host, ctx, columnsLength, out _keyValues, out _kvTypes); } - private void GetTerms(int iinfo, ref VBuffer dst) + public override void Save(ModelSaveContext ctx) { - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - Host.Assert(Utils.Size(_keyValues) == Infos.Length); - Host.Assert(_keyValues[iinfo].Length > 0); - _keyValues[iinfo].CopyTo(ref dst); + Host.CheckValue(ctx, nameof(ctx)); + + ctx.CheckAtModel(); + 
ctx.SetVersionInfo(GetVersionInfo()); + + SaveColumns(ctx); + + // + // + // Exes + Host.Assert(_columns.Length == ColumnPairs.Length); + foreach (var col in _columns) + col.Save(ctx); + + TextModelHelper.SaveAll(Host, ctx, _columns.Length, _keyValues); } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + // Factory method for SignatureLoadDataTransform. + private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); + + // Factory method for SignatureLoadRowMapper. + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + + // Factory method for SignatureDataTransform. + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - disposer = null; + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); - if (!Infos[iinfo].TypeSrc.IsVector) - return ComposeGetterOne(input, iinfo); - return ComposeGetterVec(input, iinfo); + env.CheckValue(args.Column, nameof(args.Column)); + var cols = new ColumnInfo[args.Column.Length]; + for (int i = 0; i < cols.Length; i++) + { + var item = args.Column[i]; + var kind = item.InvertHash ?? args.InvertHash; + cols[i] = new ColumnInfo(item.Source, + item.Name, + item.HashBits ?? args.HashBits, + item.Seed ?? args.Seed, + item.Ordered ?? args.Ordered, + item.InvertHash ?? 
args.InvertHash); + }; + return new HashTransform(env, input, cols).MakeDataTransform(input); } - /// - /// Getter generator for single valued inputs - /// - private ValueGetter ComposeGetterOne(IRow input, int iinfo) + #region Getters + private ValueGetter ComposeGetterOne(IRow input, int iinfo, int srcCol, ColumnType srcType) { - var colType = Infos[iinfo].TypeSrc; - Host.Assert(colType.IsText || colType.IsKey || colType == NumberType.R4 || colType == NumberType.R8); + Host.Assert(srcType.IsText || srcType.IsKey || srcType == NumberType.R4 || srcType == NumberType.R8); - var mask = (1U << _exes[iinfo].HashBits) - 1; - uint seed = _exes[iinfo].HashSeed; + var mask = (1U << _columns[iinfo].HashBits) - 1; + uint seed = _columns[iinfo].Seed; // In case of single valued input column, hash in 0 for the slot index. - if (_exes[iinfo].Ordered) + if (_columns[iinfo].Ordered) seed = Hashing.MurmurRound(seed, 0); - switch (colType.RawKind) + switch (srcType.RawKind) { - case DataKind.Text: - return ComposeGetterOneCore(GetSrcGetter(input, iinfo), seed, mask); - case DataKind.U1: - return ComposeGetterOneCore(GetSrcGetter(input, iinfo), seed, mask); - case DataKind.U2: - return ComposeGetterOneCore(GetSrcGetter(input, iinfo), seed, mask); - case DataKind.U4: - return ComposeGetterOneCore(GetSrcGetter(input, iinfo), seed, mask); - case DataKind.R4: - return ComposeGetterOneCore(GetSrcGetter(input, iinfo), seed, mask); - case DataKind.R8: - return ComposeGetterOneCore(GetSrcGetter(input, iinfo), seed, mask); - default: - Host.Assert(colType.RawKind == DataKind.U8); - return ComposeGetterOneCore(GetSrcGetter(input, iinfo), seed, mask); + case DataKind.Text: + return ComposeGetterOneCore(input.GetGetter(srcCol), seed, mask); + case DataKind.U1: + return ComposeGetterOneCore(input.GetGetter(srcCol), seed, mask); + case DataKind.U2: + return ComposeGetterOneCore(input.GetGetter(srcCol), seed, mask); + case DataKind.U4: + return ComposeGetterOneCore(input.GetGetter(srcCol), seed, 
mask); + case DataKind.R4: + return ComposeGetterOneCore(input.GetGetter(srcCol), seed, mask); + case DataKind.R8: + return ComposeGetterOneCore(input.GetGetter(srcCol), seed, mask); + default: + Host.Assert(srcType.RawKind == DataKind.U8); + return ComposeGetterOneCore(input.GetGetter(srcCol), seed, mask); } } private ValueGetter ComposeGetterOneCore(ValueGetter getSrc, uint seed, uint mask) { - DvText src = default(DvText); + DvText src = default; return (ref uint dst) => { @@ -537,43 +497,42 @@ private ValueGetter ComposeGetterOneCore(ValueGetter getSrc, uint // them (either with their index or without) into dst. Additionally it fills in zero hashes in the rest of dst elements. private delegate void HashLoopWithZeroHash(int count, int[] indices, TSrc[] src, uint[] dst, int dstCount, uint seed, uint mask); - private ValueGetter> ComposeGetterVec(IRow input, int iinfo) + private ValueGetter> ComposeGetterVec(IRow input, int iinfo, int srcCol, ColumnType srcType) { - var colType = Infos[iinfo].TypeSrc; - Host.Assert(colType.IsVector); - Host.Assert(colType.ItemType.IsText || colType.ItemType.IsKey || colType.ItemType == NumberType.R4 || colType.ItemType == NumberType.R8); + Host.Assert(srcType.IsVector); + Host.Assert(srcType.ItemType.IsText || srcType.ItemType.IsKey || srcType.ItemType == NumberType.R4 || srcType.ItemType == NumberType.R8); - switch (colType.ItemType.RawKind) + switch (srcType.ItemType.RawKind) { - case DataKind.Text: - return ComposeGetterVecCore(input, iinfo, HashUnord, HashDense, HashSparse); - case DataKind.U1: - return ComposeGetterVecCore(input, iinfo, HashUnord, HashDense, HashSparse); - case DataKind.U2: - return ComposeGetterVecCore(input, iinfo, HashUnord, HashDense, HashSparse); - case DataKind.U4: - return ComposeGetterVecCore(input, iinfo, HashUnord, HashDense, HashSparse); - case DataKind.R4: - return ComposeGetterVecCoreFloat(input, iinfo, HashSparseUnord, HashUnord, HashDense); - case DataKind.R8: - return 
ComposeGetterVecCoreFloat(input, iinfo, HashSparseUnord, HashUnord, HashDense); - default: - Host.Assert(colType.ItemType.RawKind == DataKind.U8); - return ComposeGetterVecCore(input, iinfo, HashUnord, HashDense, HashSparse); + case DataKind.Text: + return ComposeGetterVecCore(input, iinfo, srcCol, srcType, HashUnord, HashDense, HashSparse); + case DataKind.U1: + return ComposeGetterVecCore(input, iinfo, srcCol, srcType, HashUnord, HashDense, HashSparse); + case DataKind.U2: + return ComposeGetterVecCore(input, iinfo, srcCol, srcType, HashUnord, HashDense, HashSparse); + case DataKind.U4: + return ComposeGetterVecCore(input, iinfo, srcCol, srcType, HashUnord, HashDense, HashSparse); + case DataKind.R4: + return ComposeGetterVecCoreFloat(input, iinfo, srcCol, srcType, HashSparseUnord, HashUnord, HashDense); + case DataKind.R8: + return ComposeGetterVecCoreFloat(input, iinfo, srcCol, srcType, HashSparseUnord, HashUnord, HashDense); + default: + Host.Assert(srcType.ItemType.RawKind == DataKind.U8); + return ComposeGetterVecCore(input, iinfo, srcCol, srcType, HashUnord, HashDense, HashSparse); } } - private ValueGetter> ComposeGetterVecCore(IRow input, int iinfo, + private ValueGetter> ComposeGetterVecCore(IRow input, int iinfo, int srcCol, ColumnType srcType, HashLoop hasherUnord, HashLoop hasherDense, HashLoop hasherSparse) { - Host.Assert(Infos[iinfo].TypeSrc.IsVector); - Host.Assert(Infos[iinfo].TypeSrc.ItemType.RawType == typeof(T)); + Host.Assert(srcType.IsVector); + Host.Assert(srcType.ItemType.RawType == typeof(T)); - var getSrc = GetSrcGetter>(input, iinfo); - var ex = _exes[iinfo]; + var getSrc = input.GetGetter>(srcCol); + var ex = _columns[iinfo]; var mask = (1U << ex.HashBits) - 1; - var seed = ex.HashSeed; - var len = Infos[iinfo].TypeSrc.VectorSize; + var seed = ex.Seed; + var len = srcType.VectorSize; var src = default(VBuffer); if (!ex.Ordered) @@ -612,20 +571,20 @@ private ValueGetter> ComposeGetterVecCore(IRow input, int iinfo }; } - private 
ValueGetter> ComposeGetterVecCoreFloat(IRow input, int iinfo, + private ValueGetter> ComposeGetterVecCoreFloat(IRow input, int iinfo, int srcCol, ColumnType srcType, HashLoopWithZeroHash hasherSparseUnord, HashLoop hasherDenseUnord, HashLoop hasherDenseOrdered) { - Host.Assert(Infos[iinfo].TypeSrc.IsVector); - Host.Assert(Infos[iinfo].TypeSrc.ItemType.RawType == typeof(T)); + Host.Assert(srcType.IsVector); + Host.Assert(srcType.ItemType.RawType == typeof(T)); - var getSrc = GetSrcGetter>(input, iinfo); - var ex = _exes[iinfo]; + var getSrc = input.GetGetter>(srcCol); + var ex = _columns[iinfo]; var mask = (1U << ex.HashBits) - 1; - var seed = ex.HashSeed; - var len = Infos[iinfo].TypeSrc.VectorSize; + var seed = ex.Seed; + var len = srcType.VectorSize; var src = default(VBuffer); T[] denseValues = null; - int expectedSrcLength = Infos[iinfo].TypeSrc.VectorSize; + int expectedSrcLength = srcType.VectorSize; HashLoop hasherDense = ex.Ordered ? hasherDenseOrdered : hasherDenseUnord; return @@ -668,6 +627,8 @@ private ValueGetter> ComposeGetterVecCoreFloat(IRow input, int }; } + #endregion + #region Core Hash functions, with and without index [MethodImpl(MethodImplOptions.AggressiveInlining)] private static uint HashCore(uint seed, ref DvText value, uint mask) @@ -704,7 +665,7 @@ private static uint HashCore(uint seed, ref float value, int i, uint mask) if (value.IsNA()) return 0; return (Hashing.MixHash(Hashing.MurmurRound(Hashing.MurmurRound(seed, (uint)i), - FloatUtils.GetBits(value == 0 ? 0: value))) & mask) + 1; + FloatUtils.GetBits(value == 0 ? 
0 : value))) & mask) + 1; } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -787,9 +748,9 @@ private static uint HashCore(uint seed, ulong value, int i, uint mask) hash = Hashing.MurmurRound(hash, hi); return (Hashing.MixHash(hash) & mask) + 1; } -#endregion Core Hash functions, with and without index + #endregion Core Hash functions, with and without index -#region Unordered Loop: ignore indices + #region Unordered Loop: ignore indices private static void HashUnord(int count, int[] indices, DvText[] src, uint[] dst, uint seed, uint mask) { AssertValid(count, src, dst); @@ -847,7 +808,7 @@ private static void HashUnord(int count, int[] indices, double[] src, uint[] dst #endregion Unordered Loop: ignore indices -#region Dense Loop: ignore indices + #region Dense Loop: ignore indices private static void HashDense(int count, int[] indices, DvText[] src, uint[] dst, uint seed, uint mask) { AssertValid(count, src, dst); @@ -902,9 +863,9 @@ private static void HashDense(int count, int[] indices, double[] src, uint[] dst for (int i = 0; i < count; i++) dst[i] = HashCore(seed, ref src[i], i, mask); } -#endregion Dense Loop: ignore indices + #endregion Dense Loop: ignore indices -#region Sparse Loop: use indices + #region Sparse Loop: use indices private static void HashSparse(int count, int[] indices, DvText[] src, uint[] dst, uint seed, uint mask) { AssertValid(count, src, dst); @@ -992,7 +953,7 @@ private static void HashSparseUnord(int count, int[] indices, double[] src, uint } } -#endregion Sparse Loop: use indices + #endregion Sparse Loop: use indices [Conditional("DEBUG")] private static void AssertValid(int count, T[] src, uint[] dst) @@ -1002,27 +963,94 @@ private static void AssertValid(int count, T[] src, uint[] dst) Contracts.Assert(count <= Utils.Size(dst)); } - /// - /// This is a utility class to acquire and build the inverse hash to populate - /// KeyValues metadata. 
- /// + private sealed class Mapper : MapperBase + { + private sealed class ColInfo + { + public readonly string Name; + public readonly string Source; + public readonly ColumnType TypeSrc; + + public ColInfo(string name, string source, ColumnType type) + { + Name = name; + Source = source; + TypeSrc = type; + } + } + + private readonly ColumnType[] _types; + private readonly HashTransform _parent; + + public Mapper(HashTransform parent, ISchema inputSchema) + : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) + { + _parent = parent; + _types = new ColumnType[_parent._columns.Length]; + for (int i = 0; i < _types.Length; i++) + _types[i] = _parent.GetOutputType(inputSchema, _parent._columns[i]); + } + + public override RowMapperColumnInfo[] GetOutputColumns() + { + var result = new RowMapperColumnInfo[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colIndex); + var colMetaInfo = new ColumnMetadataInfo(_parent.ColumnPairs[i].output); + + foreach (var type in InputSchema.GetMetadataTypes(colIndex).Where(x => x.Key == MetadataUtils.Kinds.SlotNames)) + Utils.MarshalInvoke(AddMetaGetter, type.Value.RawType, colMetaInfo, InputSchema, type.Key, type.Value, colIndex); + if (_parent._kvTypes != null && _parent._kvTypes[i] != null) + AddMetaKeyValues(i, colMetaInfo); + result[i] = new RowMapperColumnInfo(_parent.ColumnPairs[i].output, _types[i], colMetaInfo); + } + return result; + } + private void AddMetaKeyValues(int i, ColumnMetadataInfo colMetaInfo) + { + MetadataUtils.MetadataGetter> getter = (int col, ref VBuffer dst) => + { + _parent._keyValues[i].CopyTo(ref dst); + }; + var info = new MetadataInfo>(_parent._kvTypes[i], getter); + colMetaInfo.Add(MetadataUtils.Kinds.KeyValues, info); + } + + private int AddMetaGetter(ColumnMetadataInfo colMetaInfo, ISchema schema, string kind, ColumnType ct, int originalCol) + { + MetadataUtils.MetadataGetter 
getter = (int col, ref T dst) => + { + // We don't care about 'col': this getter is specialized for a column 'originalCol', + // and 'col' in this case is the 'metadata kind index', not the column index. + schema.GetMetadata(kind, originalCol, ref dst); + }; + var info = new MetadataInfo(ct, getter); + colMetaInfo.Add(kind, info); + return 0; + } + + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) => _parent.GetGetterCore(input, iinfo, out disposer); + } + private abstract class InvertHashHelper { protected readonly IRow Row; private readonly bool _includeSlot; - private readonly ColInfo _info; - private readonly ColInfoEx _ex; + private readonly ColumnInfo _ex; + private readonly ColumnType _srcType; + private readonly int _srcCol; - private InvertHashHelper(IRow row, ColInfo info, ColInfoEx ex) + private InvertHashHelper(IRow row, ColumnInfo ex) { Contracts.AssertValue(row); - Contracts.AssertValue(info); - Row = row; - _info = info; + row.Schema.TryGetColumnIndex(ex.Input, out int srcCol); + _srcCol = srcCol; + _srcType = row.Schema.GetColumnType(srcCol); _ex = ex; // If this is a vector and ordered, then we must include the slot as part of the representation. - _includeSlot = _info.TypeSrc.IsVector && _ex.Ordered; + _includeSlot = _srcType.IsVector && _ex.Ordered; } /// @@ -1031,18 +1059,18 @@ private InvertHashHelper(IRow row, ColInfo info, ColInfoEx ex) /// the row. /// /// The input source row, from which the hashed values can be fetched - /// The column info, describing the source /// The extra column info - /// The number of input hashed values to accumulate per output hash value + /// The number of input hashed valuPres to accumulate per output hash value /// A hash getter, built on top of . 
- public static InvertHashHelper Create(IRow row, ColInfo info, ColInfoEx ex, int invertHashMaxCount, Delegate dstGetter) + public static InvertHashHelper Create(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) { - ColumnType typeSrc = info.TypeSrc; + row.Schema.TryGetColumnIndex(ex.Input, out int srcCol); + ColumnType typeSrc = row.Schema.GetColumnType(srcCol); Type t = typeSrc.IsVector ? (ex.Ordered ? typeof(ImplVecOrdered<>) : typeof(ImplVec<>)) : typeof(ImplOne<>); t = t.MakeGenericType(typeSrc.ItemType.RawType); - var consTypes = new Type[] { typeof(IRow), typeof(OneToOneTransformBase.ColInfo), typeof(ColInfoEx), typeof(int), typeof(Delegate) }; + var consTypes = new Type[] { typeof(IRow), typeof(ColumnInfo), typeof(int), typeof(Delegate) }; var constructorInfo = t.GetConstructor(consTypes); - return (InvertHashHelper)constructorInfo.Invoke(new object[] { row, info, ex, invertHashMaxCount, dstGetter }); + return (InvertHashHelper)constructorInfo.Invoke(new object[] { row, ex, invertHashMaxCount, dstGetter }); } /// @@ -1098,7 +1126,7 @@ public int GetHashCode(KeyValuePair obj) private IEqualityComparer GetSimpleComparer() { - Contracts.Assert(_info.TypeSrc.ItemType.RawType == typeof(T)); + Contracts.Assert(_srcType.ItemType.RawType == typeof(T)); if (typeof(T) == typeof(DvText)) { // We are hashing twice, once to assign to the slot, and then again, @@ -1106,7 +1134,7 @@ private IEqualityComparer GetSimpleComparer() // same seed used to assign to a slot, or otherwise this per-slot hash // would have a lot of collisions. We ensure that we have different // hash function by inverting the seed's bits. 
- var c = new TextEqualityComparer(~_ex.HashSeed); + var c = new TextEqualityComparer(~_ex.Seed); return c as IEqualityComparer; } // I assume I hope correctly that the default .NET hash function for uint @@ -1120,11 +1148,10 @@ private abstract class Impl : InvertHashHelper { protected readonly InvertHashCollector Collector; - protected Impl(IRow row, ColInfo info, ColInfoEx ex, int invertHashMaxCount) - : base(row, info, ex) + protected Impl(IRow row, ColumnInfo ex, int invertHashMaxCount) + : base(row, ex) { Contracts.AssertValue(row); - Contracts.AssertValue(info); Contracts.AssertValue(ex); Collector = new InvertHashCollector(1 << ex.HashBits, invertHashMaxCount, GetTextMap(), GetComparer()); @@ -1132,7 +1159,7 @@ protected Impl(IRow row, ColInfo info, ColInfoEx ex, int invertHashMaxCount) protected virtual ValueMapper GetTextMap() { - return InvertHashUtils.GetSimpleMapper(Row.Schema, _info.Source); + return InvertHashUtils.GetSimpleMapper(Row.Schema, _srcCol); } protected virtual IEqualityComparer GetComparer() @@ -1154,10 +1181,10 @@ private sealed class ImplOne : Impl private T _value; private uint _hash; - public ImplOne(IRow row, OneToOneTransformBase.ColInfo info, ColInfoEx ex, int invertHashMaxCount, Delegate dstGetter) - : base(row, info, ex, invertHashMaxCount) + public ImplOne(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) + : base(row, ex, invertHashMaxCount) { - _srcGetter = Row.GetGetter(_info.Source); + _srcGetter = Row.GetGetter(_srcCol); _dstGetter = dstGetter as ValueGetter; Contracts.AssertValue(_dstGetter); } @@ -1188,10 +1215,10 @@ private sealed class ImplVec : Impl private VBuffer _value; private VBuffer _hash; - public ImplVec(IRow row, OneToOneTransformBase.ColInfo info, ColInfoEx ex, int invertHashMaxCount, Delegate dstGetter) - : base(row, info, ex, invertHashMaxCount) + public ImplVec(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) + : base(row, ex, invertHashMaxCount) { - _srcGetter = 
Row.GetGetter>(_info.Source); + _srcGetter = Row.GetGetter>(_srcCol); _dstGetter = dstGetter as ValueGetter>; Contracts.AssertValue(_dstGetter); } @@ -1217,17 +1244,17 @@ private sealed class ImplVecOrdered : Impl> private VBuffer _value; private VBuffer _hash; - public ImplVecOrdered(IRow row, OneToOneTransformBase.ColInfo info, ColInfoEx ex, int invertHashMaxCount, Delegate dstGetter) - : base(row, info, ex, invertHashMaxCount) + public ImplVecOrdered(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) + : base(row, ex, invertHashMaxCount) { - _srcGetter = Row.GetGetter>(_info.Source); + _srcGetter = Row.GetGetter>(_srcCol); _dstGetter = dstGetter as ValueGetter>; Contracts.AssertValue(_dstGetter); } protected override ValueMapper, StringBuilder> GetTextMap() { - var simple = InvertHashUtils.GetSimpleMapper(Row.Schema, _info.Source); + var simple = InvertHashUtils.GetSimpleMapper(Row.Schema, _srcCol); return InvertHashUtils.GetPairMapper(simple); } @@ -1258,5 +1285,65 @@ public override void Process() } } } + + public sealed class HashEstimator : IEstimator + { + public const int NumBitsMin = 1; + public const int NumBitsLim = 32; + + public static class Defaults + { + public const int HashBits = NumBitsLim - 1; + public const uint Seed = 314489979; + public const bool Ordered = false; + public const int InvertHash = 0; + } + + private readonly IHost _host; + private readonly HashTransform.ColumnInfo[] _columns; + + public HashEstimator(IHostEnvironment env, string name, string source = null, + int hashBits = Defaults.HashBits, int invertHash = Defaults.InvertHash) + : this(env, new HashTransform.ColumnInfo(name, source ?? 
name, hashBits: hashBits, invertHash: invertHash)) + { + } + + public HashEstimator(IHostEnvironment env, params HashTransform.ColumnInfo[] columns) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(HashEstimator)); + _columns = columns.ToArray(); + foreach (var col in _columns) + { + if (col.InvertHash < -1) + throw _host.ExceptParam(nameof(columns), "Value too small, must be -1 or larger"); + if (col.InvertHash != 0 && col.HashBits >= 31) + throw _host.ExceptParam(nameof(columns), $"Cannot support invertHash for a {0} bit hash. 30 is the maximum possible.", col.HashBits); + } + } + + public HashTransform Fit(IDataView input) => new HashTransform(_host, input, _columns); + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var colInfo in _columns) + { + if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + string reason = HashTransform.TestType(col.ItemType); + if (reason != null) + throw _host.ExceptParam(nameof(inputSchema), reason); + var metadata = new List(); + if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta)) + metadata.Add(slotMeta); + if (colInfo.InvertHash != 0) + metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)); + result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.ItemType.IsVector ? 
SchemaShape.Column.VectorKind.Vector : SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, new SchemaShape(metadata.ToArray())); + } + return new SchemaShape(result.Values); + } + } } diff --git a/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs b/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs index 582586b9a0..2f4c89eb67 100644 --- a/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs +++ b/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs @@ -198,7 +198,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV return CreateTransformCore( args.OutputKind, args.Column, args.Column.Select(col => col.OutputKind).ToList(), - new HashTransform(h, hashArgs, input), + HashTransform.Create(h, hashArgs, input), h, args); } diff --git a/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs index 9417ea6260..48d5a17c26 100644 --- a/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs @@ -444,7 +444,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV InvertHash = args.InvertHash }; - view = new HashTransform(h, hashArgs, view); + view = HashTransform.Create(h, hashArgs, view); // creating the NgramHash function var ngramHashArgs = diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index 3abff3e560..67223aec2e 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -89,7 +89,7 @@ private void TestHashTransformHelper(T[] data, uint[] results, NumberType typ HashTransform.Arguments args = new HashTransform.Arguments(); args.Column = new HashTransform.Column[] { col }; - var hashTransform = new HashTransform(Env, args, srcView); + var hashTransform = HashTransform.Create(Env, args, srcView); using (var 
cursor = hashTransform.GetRowCursor(c => true)) { var resultGetter = cursor.GetGetter(1); @@ -127,7 +127,7 @@ private void TestHashTransformVectorHelper(ArrayDataViewBuilder builder, uint[][ HashTransform.Arguments args = new HashTransform.Arguments(); args.Column = new HashTransform.Column[] { col }; - var hashTransform = new HashTransform(Env, args, srcView); + var hashTransform = HashTransform.Create(Env, args, srcView); using (var cursor = hashTransform.GetRowCursor(c => true)) { var resultGetter = cursor.GetGetter>(1); diff --git a/test/Microsoft.ML.Tests/Transformers/HashTests.cs b/test/Microsoft.ML.Tests/Transformers/HashTests.cs new file mode 100644 index 0000000000..4f3b290ed3 --- /dev/null +++ b/test/Microsoft.ML.Tests/Transformers/HashTests.cs @@ -0,0 +1,59 @@ +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.RunTests; +using System; +using System.Collections.Generic; +using System.Text; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests.Transformers +{ + public class HashTests : TestDataPipeBase + { + public HashTests(ITestOutputHelper output) : base(output) + { + } + + private class TestClass + { + public float A; + public float B; + public float C; + } + + /* private class TestMeta + { + [VectorType(2)] + public string[] A; + public string B; + [VectorType(2)] + public int[] C; + public int D; + [VectorType(2)] + public float[] E; + public float F; + [VectorType(2)] + public string[] G; + public string H; + }*/ + + [Fact] + public void HashWorkout() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + + var dataView = ComponentCreation.CreateDataView(Env, data); + var pipe = new HashEstimator(Env, new[]{ + new HashTransform.ColumnInfo("A", "CatA", hashBits:4, invertHash:-1), + new HashTransform.ColumnInfo("B", "CatB", hashBits:3, ordered:true), + new HashTransform.ColumnInfo("C", "CatC", seed:42), + new 
HashTransform.ColumnInfo("A", "CatD"), + }); + + TestEstimatorCore(pipe, dataView); + Done(); + } + + } +} From 249a4e73a3b56eaea664b1ab29428bce7349720e Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Wed, 19 Sep 2018 10:06:15 -0700 Subject: [PATCH 2/8] small changes --- test/Microsoft.ML.Tests/Transformers/HashTests.cs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/Microsoft.ML.Tests/Transformers/HashTests.cs b/test/Microsoft.ML.Tests/Transformers/HashTests.cs index 4f3b290ed3..470b6c7b04 100644 --- a/test/Microsoft.ML.Tests/Transformers/HashTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/HashTests.cs @@ -1,9 +1,10 @@ -using Microsoft.ML.Runtime.Api; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.RunTests; -using System; -using System.Collections.Generic; -using System.Text; using Xunit; using Xunit.Abstractions; From 6bc51a7d402481ab41055c9b932dcc86c13b4573 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Wed, 19 Sep 2018 10:41:59 -0700 Subject: [PATCH 3/8] fix tests --- .../Transforms/HashTransform.cs | 13 ++- .../DataPipe/TestDataPipe.cs | 2 +- .../Transformers/HashTests.cs | 102 +++++++++++++++--- 3 files changed, 93 insertions(+), 24 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/HashTransform.cs b/src/Microsoft.ML.Data/Transforms/HashTransform.cs index 61413b7473..bd944a062d 100644 --- a/src/Microsoft.ML.Data/Transforms/HashTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/HashTransform.cs @@ -2,17 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.CompilerServices; -using System.Text; using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; [assembly: LoadableClass(HashTransform.Summary, typeof(IDataTransform), typeof(HashTransform), typeof(HashTransform.Arguments), typeof(SignatureDataTransform), "Hash Transform", "HashTransform", "Hash", DocName = "transform/HashTransform.md")] @@ -232,7 +232,6 @@ public HashTransform(IHostEnvironment env, IDataView input, ColumnInfo[] columns base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) { _columns = columns.ToArray(); - //IVAN: Validate input schema var types = new ColumnType[_columns.Length]; List invertIinfos = null; List invertHashMaxCounts = null; @@ -369,7 +368,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV { var item = args.Column[i]; var kind = item.InvertHash ?? args.InvertHash; - cols[i] = new ColumnInfo(item.Source, + cols[i] = new ColumnInfo(item.Source ?? item.Name, item.Name, item.HashBits ?? args.HashBits, item.Seed ?? 
args.Seed, diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index 67223aec2e..1d91e3b93c 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -83,7 +83,7 @@ private void TestHashTransformHelper(T[] data, uint[] results, NumberType typ var srcView = builder.GetDataView(); HashTransform.Column col = new HashTransform.Column(); - col.Source = "F1"; + col.Name = "F1"; col.HashBits = 5; col.Seed = 42; HashTransform.Arguments args = new HashTransform.Arguments(); diff --git a/test/Microsoft.ML.Tests/Transformers/HashTests.cs b/test/Microsoft.ML.Tests/Transformers/HashTests.cs index 470b6c7b04..d51f00d26d 100644 --- a/test/Microsoft.ML.Tests/Transformers/HashTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/HashTests.cs @@ -4,7 +4,11 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.Runtime.Tools; +using System.IO; +using System.Linq; using Xunit; using Xunit.Abstractions; @@ -23,21 +27,15 @@ private class TestClass public float C; } - /* private class TestMeta + private class TestMeta { [VectorType(2)] - public string[] A; - public string B; - [VectorType(2)] - public int[] C; - public int D; - [VectorType(2)] - public float[] E; - public float F; + public float[] A; + public float B; [VectorType(2)] - public string[] G; - public string H; - }*/ + public double[] C; + public double D; + } [Fact] public void HashWorkout() @@ -46,15 +44,87 @@ public void HashWorkout() var dataView = ComponentCreation.CreateDataView(Env, data); var pipe = new HashEstimator(Env, new[]{ - new HashTransform.ColumnInfo("A", "CatA", hashBits:4, invertHash:-1), - new HashTransform.ColumnInfo("B", "CatB", hashBits:3, ordered:true), - new HashTransform.ColumnInfo("C", "CatC", seed:42), - new HashTransform.ColumnInfo("A", 
"CatD"), + new HashTransform.ColumnInfo("A", "HashA", hashBits:4, invertHash:-1), + new HashTransform.ColumnInfo("B", "HashB", hashBits:3, ordered:true), + new HashTransform.ColumnInfo("C", "HashC", seed:42), + new HashTransform.ColumnInfo("A", "HashD"), }); TestEstimatorCore(pipe, dataView); Done(); } + [Fact] + public void TestMetadata() + { + + var data = new[] { + new TestMeta() { A=new float[2] { 3.5f, 2.5f}, B=1, C= new double[2] { 5.1f, 6.1f}, D= 7}, + new TestMeta() { A=new float[2] { 3.5f, 2.5f}, B=1, C= new double[2] { 5.1f, 6.1f}, D= 7}, + new TestMeta() { A=new float[2] { 3.5f, 2.5f}, B=1, C= new double[2] { 5.1f, 6.1f}, D= 7}}; + + + var dataView = ComponentCreation.CreateDataView(Env, data); + var pipe = new HashEstimator(Env, new[] { + new HashTransform.ColumnInfo("A", "HashA", invertHash:1, hashBits:10), + new HashTransform.ColumnInfo("A", "HashAUnlim", invertHash:-1, hashBits:10), + new HashTransform.ColumnInfo("A", "HashAUnlimOrdered", invertHash:-1, hashBits:10, ordered:true) + }); + var result = pipe.Fit(dataView).Transform(dataView); + ValidateMetadata(result); + Done(); + } + + private void ValidateMetadata(IDataView result) + { + + Assert.True(result.Schema.TryGetColumnIndex("HashA", out int HashA)); + Assert.True(result.Schema.TryGetColumnIndex("HashAUnlim", out int HashAUnlim)); + Assert.True(result.Schema.TryGetColumnIndex("HashAUnlimOrdered", out int HashAUnlimOrdered)); + VBuffer keys = default; + var types = result.Schema.GetMetadataTypes(HashA); + Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.KeyValues }); + result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, HashA, ref keys); + Assert.True(keys.Length == 1024); + Assert.Equal(keys.Items().Select(x => x.Value.ToString()), new string[2] {"2.5", "3.5" }); + + types = result.Schema.GetMetadataTypes(HashAUnlim); + Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.KeyValues }); + 
result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, HashA, ref keys); + Assert.True(keys.Length == 1024); + Assert.Equal(keys.Items().Select(x => x.Value.ToString()), new string[2] { "2.5", "3.5" }); + + types = result.Schema.GetMetadataTypes(HashAUnlimOrdered); + Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.KeyValues }); + result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, HashA, ref keys); + Assert.True(keys.Length == 1024); + Assert.Equal(keys.Items().Select(x => x.Value.ToString()), new string[2] { "2.5", "3.5" }); + } + [Fact] + public void TestCommandLine() + { + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Hash{col=B:A} in=f:\2.txt" }), (int)0); + } + + [Fact] + public void TestOldSavingAndLoading() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + var pipe = new HashEstimator(Env, new[]{ + new HashTransform.ColumnInfo("A", "HashA", hashBits:4, invertHash:-1), + new HashTransform.ColumnInfo("B", "HashB", hashBits:3, ordered:true), + new HashTransform.ColumnInfo("C", "HashC", seed:42), + new HashTransform.ColumnInfo("A", "HashD"), + }); + var result = pipe.Fit(dataView).Transform(dataView); + var resultRoles = new RoleMappedData(result); + using (var ms = new MemoryStream()) + { + TrainUtils.SaveModel(Env, Env.Start("saving"), ms, null, resultRoles); + ms.Position = 0; + var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms); + } + } } } From 1abf183c8ca016bb1a76a4e817886f34037c9414 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Wed, 19 Sep 2018 11:55:06 -0700 Subject: [PATCH 4/8] address comments --- .../Commands/CrossValidationCommand.cs | 7 +- .../Transforms/HashTransform.cs | 113 ++++++++++-------- .../CategoricalHashTransform.cs | 11 +- .../Text/WordBagTransform.cs | 3 +- .../Text/WordHashBagTransform.cs | 11 +- .../DataPipe/TestDataPipe.cs | 17 +-- 
.../Transformers/HashTests.cs | 29 ++--- 7 files changed, 105 insertions(+), 86 deletions(-) diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index f1c720ceef..e653443569 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -13,6 +13,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Transforms; [assembly: LoadableClass(typeof(CrossValidationCommand), typeof(CrossValidationCommand.Arguments), typeof(SignatureCommand), "Cross Validation", CrossValidationCommand.LoadName)] @@ -329,10 +330,8 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output int inc = 0; while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp)) stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc); - var hashargs = new HashTransform.Arguments(); - hashargs.Column = new[] { new HashTransform.Column { Source = origStratCol, Name = stratificationColumn } }; - hashargs.HashBits = 30; - output = HashTransform.Create(Host, hashargs, input); + var est = new HashConverter(Host, stratificationColumn, origStratCol, 30); + output = est.Fit(input).Transform(input); } } diff --git a/src/Microsoft.ML.Data/Transforms/HashTransform.cs b/src/Microsoft.ML.Data/Transforms/HashTransform.cs index bd944a062d..9b3ada1494 100644 --- a/src/Microsoft.ML.Data/Transforms/HashTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/HashTransform.cs @@ -8,25 +8,26 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Transforms; using System; using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; using System.Text; -[assembly: LoadableClass(HashTransform.Summary, typeof(IDataTransform), 
typeof(HashTransform), typeof(HashTransform.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(HashConverterTransformer.Summary, typeof(IDataTransform), typeof(HashConverterTransformer), typeof(HashConverterTransformer.Arguments), typeof(SignatureDataTransform), "Hash Transform", "HashTransform", "Hash", DocName = "transform/HashTransform.md")] -[assembly: LoadableClass(HashTransform.Summary, typeof(IDataTransform), typeof(HashTransform), null, typeof(SignatureLoadDataTransform), - "Hash Transform", HashTransform.LoaderSignature)] +[assembly: LoadableClass(HashConverterTransformer.Summary, typeof(IDataTransform), typeof(HashConverterTransformer), null, typeof(SignatureLoadDataTransform), + "Hash Transform", HashConverterTransformer.LoaderSignature)] -[assembly: LoadableClass(HashTransform.Summary, typeof(HashTransform), null, typeof(SignatureLoadModel), - "Hash Transform", HashTransform.LoaderSignature)] +[assembly: LoadableClass(HashConverterTransformer.Summary, typeof(HashConverterTransformer), null, typeof(SignatureLoadModel), + "Hash Transform", HashConverterTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(IRowMapper), typeof(HashTransform), null, typeof(SignatureLoadRowMapper), - "Hash Transform", HashTransform.LoaderSignature)] +[assembly: LoadableClass(typeof(IRowMapper), typeof(HashConverterTransformer), null, typeof(SignatureLoadRowMapper), + "Hash Transform", HashConverterTransformer.LoaderSignature)] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Transforms { using Conditional = System.Diagnostics.ConditionalAttribute; @@ -35,7 +36,7 @@ namespace Microsoft.ML.Runtime.Data /// it hashes each slot separately. /// It can hash either text values or key values. 
/// - public sealed class HashTransform : OneToOneTransformerBase + public sealed class HashConverterTransformer : OneToOneTransformerBase { public sealed class Arguments { @@ -45,18 +46,18 @@ public sealed class Arguments [Argument(ArgumentType.AtMostOnce, HelpText = "Number of bits to hash into. Must be between 1 and 31, inclusive", ShortName = "bits", SortOrder = 2)] - public int HashBits = HashEstimator.Defaults.HashBits; + public int HashBits = HashConverter.Defaults.HashBits; [Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")] - public uint Seed = HashEstimator.Defaults.Seed; + public uint Seed = HashConverter.Defaults.Seed; [Argument(ArgumentType.AtMostOnce, HelpText = "Whether the position of each term should be included in the hash", ShortName = "ord")] - public bool Ordered = HashEstimator.Defaults.Ordered; + public bool Ordered = HashConverter.Defaults.Ordered; [Argument(ArgumentType.AtMostOnce, HelpText = "Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.", ShortName = "ih")] - public int InvertHash = HashEstimator.Defaults.InvertHash; + public int InvertHash = HashConverter.Defaults.InvertHash; } public sealed class Column : OneToOneColumn @@ -133,11 +134,16 @@ public sealed class ColumnInfo /// Whether the position of each term should be included in the hash. /// Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit. 
public ColumnInfo(string input, string output, - int hashBits = HashEstimator.Defaults.HashBits, - uint seed = HashEstimator.Defaults.Seed, - bool ordered = HashEstimator.Defaults.Ordered, - int invertHash = HashEstimator.Defaults.InvertHash) + int hashBits = HashConverter.Defaults.HashBits, + uint seed = HashConverter.Defaults.Seed, + bool ordered = HashConverter.Defaults.Ordered, + int invertHash = HashConverter.Defaults.InvertHash) { + if (invertHash < -1) + throw Contracts.ExceptParam(nameof(invertHash), "Value too small, must be -1 or larger"); + if (invertHash != 0 && hashBits >= 31) + throw Contracts.ExceptParam(nameof(hashBits), $"Cannot support invertHash for a {0} bit hash. 30 is the maximum possible.", hashBits); + Input = input; Output = output; HashBits = hashBits; @@ -155,7 +161,7 @@ internal ColumnInfo(string input, string output, ModelLoadContext ctx) // uint: HashSeed // byte: Ordered HashBits = ctx.Reader.ReadInt32(); - Contracts.CheckDecode(HashEstimator.NumBitsMin <= HashBits && HashBits < HashEstimator.NumBitsLim); + Contracts.CheckDecode(HashConverter.NumBitsMin <= HashBits && HashBits < HashConverter.NumBitsLim); Seed = ctx.Reader.ReadUInt32(); Ordered = ctx.Reader.ReadBoolByte(); } @@ -167,7 +173,7 @@ internal void Save(ModelSaveContext ctx) // uint: HashSeed // byte: Ordered - Contracts.Assert(HashEstimator.NumBitsMin <= HashBits && HashBits < HashEstimator.NumBitsLim); + Contracts.Assert(HashConverter.NumBitsMin <= HashBits && HashBits < HashConverter.NumBitsLim); ctx.Writer.Write(HashBits); ctx.Writer.Write(Seed); @@ -175,13 +181,6 @@ internal void Save(ModelSaveContext ctx) } } - public static string TestType(ColumnType type) - { - if (type.ItemType.IsText || type.ItemType.IsKey || type.ItemType == NumberType.R4 || type.ItemType == NumberType.R8) - return null; - return "Expected Text, Key, Single or Double item type"; - } - private const string RegistrationName = "Hash"; internal const string Summary = "Converts column values into 
hashes. This transform accepts text and keys as inputs. It works on single- and vector-valued columns, " @@ -206,17 +205,17 @@ private static VersionInfo GetVersionInfo() protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { var type = inputSchema.GetColumnType(srcCol); - string reason = TestType(type); - if (reason != null) - throw Host.ExceptParam(nameof(inputSchema), reason); + if (!HashConverter.IsColumnTypeValid(type)) + throw Host.ExceptParam(nameof(inputSchema), HashConverter.ExpectedColumnType); } + private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) { - Contracts.CheckValue(columns, nameof(columns)); + Contracts.CheckNonEmpty(columns, nameof(columns)); return columns.Select(x => (x.Input, x.Output)).ToArray(); } - internal ColumnType GetOutputType(ISchema inputSchema, ColumnInfo column) + private ColumnType GetOutputType(ISchema inputSchema, ColumnInfo column) { var keyCount = column.HashBits < 31 ? 1 << column.HashBits : 0; inputSchema.TryGetColumnIndex(column.Input, out int srcCol); @@ -228,7 +227,18 @@ internal ColumnType GetOutputType(ISchema inputSchema, ColumnInfo column) return new VectorType(itemType, srcType.VectorSize); } - public HashTransform(IHostEnvironment env, IDataView input, ColumnInfo[] columns) : + public HashConverterTransformer(IHostEnvironment env, ColumnInfo[] columns): + base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) + { + _columns = columns.ToArray(); + foreach(var column in _columns) + { + if (column.InvertHash != 0) + throw Host.ExceptParam(nameof(columns), $"Found colunm with {nameof(column.InvertHash)} set to non zero value, please use { nameof(HashConverter)} instead"); + } + } + + internal HashConverterTransformer(IHostEnvironment env, IDataView input, ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) { _columns = columns.ToArray(); @@ -307,7 
+317,7 @@ internal Delegate GetGetterCore(IRow input, int iinfo, out Action disposer) protected override IRowMapper MakeRowMapper(ISchema schema) => new Mapper(this, schema); // Factory method for SignatureLoadModel. - private static HashTransform Create(IHostEnvironment env, ModelLoadContext ctx) + private static HashConverterTransformer Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); var host = env.Register(RegistrationName); @@ -315,10 +325,10 @@ private static HashTransform Create(IHostEnvironment env, ModelLoadContext ctx) host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new HashTransform(host, ctx); + return new HashConverterTransformer(host, ctx); } - private HashTransform(IHost host, ModelLoadContext ctx) + private HashConverterTransformer(IHost host, ModelLoadContext ctx) : base(host, ctx) { var columnsLength = ColumnPairs.Length; @@ -339,7 +349,7 @@ public override void Save(ModelSaveContext ctx) // // - // Exes + // Host.Assert(_columns.Length == ColumnPairs.Length); foreach (var col in _columns) col.Save(ctx); @@ -375,7 +385,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV item.Ordered ?? args.Ordered, item.InvertHash ?? 
args.InvertHash); }; - return new HashTransform(env, input, cols).MakeDataTransform(input); + return new HashConverterTransformer(env, input, cols).MakeDataTransform(input); } #region Getters @@ -979,9 +989,9 @@ public ColInfo(string name, string source, ColumnType type) } private readonly ColumnType[] _types; - private readonly HashTransform _parent; + private readonly HashConverterTransformer _parent; - public Mapper(HashTransform parent, ISchema inputSchema) + public Mapper(HashConverterTransformer parent, ISchema inputSchema) : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { _parent = parent; @@ -1285,7 +1295,7 @@ public override void Process() } } - public sealed class HashEstimator : IEstimator + public sealed class HashConverter : IEstimator { public const int NumBitsMin = 1; public const int NumBitsLim = 32; @@ -1299,18 +1309,24 @@ public static class Defaults } private readonly IHost _host; - private readonly HashTransform.ColumnInfo[] _columns; + private readonly HashConverterTransformer.ColumnInfo[] _columns; + + public static bool IsColumnTypeValid(ColumnType type) + { + return (type.ItemType.IsText || type.ItemType.IsKey || type.ItemType == NumberType.R4 || type.ItemType == NumberType.R8); + } + internal const string ExpectedColumnType = "Expected Text, Key, Single or Double item type"; - public HashEstimator(IHostEnvironment env, string name, string source = null, + public HashConverter(IHostEnvironment env, string name, string source = null, int hashBits = Defaults.HashBits, int invertHash = Defaults.InvertHash) - : this(env, new HashTransform.ColumnInfo(name, source ?? name, hashBits: hashBits, invertHash: invertHash)) + : this(env, new HashConverterTransformer.ColumnInfo(name, source ?? 
name, hashBits: hashBits, invertHash: invertHash)) { } - public HashEstimator(IHostEnvironment env, params HashTransform.ColumnInfo[] columns) + public HashConverter(IHostEnvironment env, params HashConverterTransformer.ColumnInfo[] columns) { Contracts.CheckValue(env, nameof(env)); - _host = env.Register(nameof(HashEstimator)); + _host = env.Register(nameof(HashConverter)); _columns = columns.ToArray(); foreach (var col in _columns) { @@ -1321,7 +1337,7 @@ public HashEstimator(IHostEnvironment env, params HashTransform.ColumnInfo[] col } } - public HashTransform Fit(IDataView input) => new HashTransform(_host, input, _columns); + public HashConverterTransformer Fit(IDataView input) => new HashConverterTransformer(_host, input, _columns); public SchemaShape GetOutputSchema(SchemaShape inputSchema) { @@ -1331,9 +1347,8 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) { if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); - string reason = HashTransform.TestType(col.ItemType); - if (reason != null) - throw _host.ExceptParam(nameof(inputSchema), reason); + if (!IsColumnTypeValid(col.ItemType)) + throw _host.ExceptParam(nameof(inputSchema), ExpectedColumnType); var metadata = new List(); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta)) metadata.Add(slotMeta); diff --git a/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs b/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs index 2f4c89eb67..7f1a20bffc 100644 --- a/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs +++ b/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs @@ -10,6 +10,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Transforms; [assembly: LoadableClass(CategoricalHashTransform.Summary, typeof(IDataTransform), typeof(CategoricalHashTransform), 
typeof(CategoricalHashTransform.Arguments), typeof(SignatureDataTransform), CategoricalHashTransform.UserName, "CategoricalHashTransform", "CatHashTransform", "CategoricalHash", "CatHash")] @@ -91,7 +92,7 @@ private static class Defaults } /// - /// This class is a merger of and + /// This class is a merger of and /// with join option removed /// public sealed class Arguments : TransformInputBase @@ -169,13 +170,13 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV throw h.ExceptUserArg(nameof(args.HashBits), "Number of bits must be between 1 and {0}", NumBitsLim - 1); // creating the Hash function - var hashArgs = new HashTransform.Arguments + var hashArgs = new HashConverterTransformer.Arguments { HashBits = args.HashBits, Seed = args.Seed, Ordered = args.Ordered, InvertHash = args.InvertHash, - Column = new HashTransform.Column[args.Column.Length] + Column = new HashConverterTransformer.Column[args.Column.Length] }; for (int i = 0; i < args.Column.Length; i++) { @@ -184,7 +185,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV throw h.ExceptUserArg(nameof(Column.Name)); h.Assert(!string.IsNullOrWhiteSpace(column.Name)); h.Assert(!string.IsNullOrWhiteSpace(column.Source)); - hashArgs.Column[i] = new HashTransform.Column + hashArgs.Column[i] = new HashConverterTransformer.Column { HashBits = column.HashBits, Seed = column.Seed, @@ -198,7 +199,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV return CreateTransformCore( args.OutputKind, args.Column, args.Column.Select(col => col.OutputKind).ToList(), - HashTransform.Create(h, hashArgs, input), + HashConverterTransformer.Create(h, hashArgs, input), h, args); } diff --git a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs index 1340855547..bab50fb980 100644 --- a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs +++ 
b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs @@ -13,6 +13,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Transforms; [assembly: LoadableClass(WordBagTransform.Summary, typeof(IDataTransform), typeof(WordBagTransform), typeof(WordBagTransform.Arguments), typeof(SignatureDataTransform), "Word Bag Transform", "WordBagTransform", "WordBag")] @@ -474,7 +475,7 @@ public interface INgramExtractorFactory { /// /// Whether the extractor transform created by this factory uses the hashing trick - /// (by using or , for example). + /// (by using or , for example). /// bool UseHashingTrick { get; } diff --git a/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs index 48d5a17c26..520e9d710d 100644 --- a/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs @@ -11,6 +11,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Transforms; [assembly: LoadableClass(WordHashBagTransform.Summary, typeof(IDataTransform), typeof(WordHashBagTransform), typeof(WordHashBagTransform.Arguments), typeof(SignatureDataTransform), "Word Hash Bag Transform", "WordHashBagTransform", "WordHashBag")] @@ -266,7 +267,7 @@ public bool TryUnparse(StringBuilder sb) } /// - /// This class is a merger of and + /// This class is a merger of and /// , with the ordered option, /// the rehashUnigrams option and the allLength option removed. 
/// @@ -340,7 +341,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV List termCols = null; if (termLoaderArgs != null) termCols = new List(); - var hashColumns = new List(); + var hashColumns = new List(); var ngramHashColumns = new NgramHashTransform.Column[args.Column.Length]; var colCount = args.Column.Length; @@ -371,7 +372,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV } hashColumns.Add( - new HashTransform.Column + new HashConverterTransformer.Column { Name = tmpName, Source = termLoaderArgs == null ? column.Source[isrc] : tmpName, @@ -435,7 +436,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV // Args for the Hash function with multiple columns var hashArgs = - new HashTransform.Arguments + new HashConverterTransformer.Arguments { HashBits = 31, Seed = args.Seed, @@ -444,7 +445,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV InvertHash = args.InvertHash }; - view = HashTransform.Create(h, hashArgs, view); + view = HashConverterTransformer.Create(h, hashArgs, view); // creating the NgramHash function var ngramHashArgs = diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index 1d91e3b93c..99739d1fd1 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -14,6 +14,7 @@ using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.TextAnalytics; using Xunit; +using Microsoft.ML.Transforms; namespace Microsoft.ML.Runtime.RunTests { @@ -82,14 +83,14 @@ private void TestHashTransformHelper(T[] data, uint[] results, NumberType typ builder.AddColumn("F1", type, data); var srcView = builder.GetDataView(); - HashTransform.Column col = new HashTransform.Column(); + var col = new HashConverterTransformer.Column(); col.Name = "F1"; col.HashBits = 5; col.Seed = 42; - 
HashTransform.Arguments args = new HashTransform.Arguments(); - args.Column = new HashTransform.Column[] { col }; + var args = new HashConverterTransformer.Arguments(); + args.Column = new HashConverterTransformer.Column[] { col }; - var hashTransform = HashTransform.Create(Env, args, srcView); + var hashTransform = HashConverterTransformer.Create(Env, args, srcView); using (var cursor = hashTransform.GetRowCursor(c => true)) { var resultGetter = cursor.GetGetter(1); @@ -120,14 +121,14 @@ private void TestHashTransformVectorHelper(VBuffer data, uint[][] results, private void TestHashTransformVectorHelper(ArrayDataViewBuilder builder, uint[][] results) { var srcView = builder.GetDataView(); - HashTransform.Column col = new HashTransform.Column(); + var col = new HashConverterTransformer.Column(); col.Source = "F1V"; col.HashBits = 5; col.Seed = 42; - HashTransform.Arguments args = new HashTransform.Arguments(); - args.Column = new HashTransform.Column[] { col }; + var args = new HashConverterTransformer.Arguments(); + args.Column = new HashConverterTransformer.Column[] { col }; - var hashTransform = HashTransform.Create(Env, args, srcView); + var hashTransform = HashConverterTransformer.Create(Env, args, srcView); using (var cursor = hashTransform.GetRowCursor(c => true)) { var resultGetter = cursor.GetGetter>(1); diff --git a/test/Microsoft.ML.Tests/Transformers/HashTests.cs b/test/Microsoft.ML.Tests/Transformers/HashTests.cs index d51f00d26d..e97c29d890 100644 --- a/test/Microsoft.ML.Tests/Transformers/HashTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/HashTests.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.Runtime.Tools; +using Microsoft.ML.Transforms; using System.IO; using System.Linq; using Xunit; @@ -43,11 +44,11 @@ public void HashWorkout() var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; var dataView = 
ComponentCreation.CreateDataView(Env, data); - var pipe = new HashEstimator(Env, new[]{ - new HashTransform.ColumnInfo("A", "HashA", hashBits:4, invertHash:-1), - new HashTransform.ColumnInfo("B", "HashB", hashBits:3, ordered:true), - new HashTransform.ColumnInfo("C", "HashC", seed:42), - new HashTransform.ColumnInfo("A", "HashD"), + var pipe = new HashConverter(Env, new[]{ + new HashConverterTransformer.ColumnInfo("A", "HashA", hashBits:4, invertHash:-1), + new HashConverterTransformer.ColumnInfo("B", "HashB", hashBits:3, ordered:true), + new HashConverterTransformer.ColumnInfo("C", "HashC", seed:42), + new HashConverterTransformer.ColumnInfo("A", "HashD"), }); TestEstimatorCore(pipe, dataView); @@ -65,10 +66,10 @@ public void TestMetadata() var dataView = ComponentCreation.CreateDataView(Env, data); - var pipe = new HashEstimator(Env, new[] { - new HashTransform.ColumnInfo("A", "HashA", invertHash:1, hashBits:10), - new HashTransform.ColumnInfo("A", "HashAUnlim", invertHash:-1, hashBits:10), - new HashTransform.ColumnInfo("A", "HashAUnlimOrdered", invertHash:-1, hashBits:10, ordered:true) + var pipe = new HashConverter(Env, new[] { + new HashConverterTransformer.ColumnInfo("A", "HashA", invertHash:1, hashBits:10), + new HashConverterTransformer.ColumnInfo("A", "HashAUnlim", invertHash:-1, hashBits:10), + new HashConverterTransformer.ColumnInfo("A", "HashAUnlimOrdered", invertHash:-1, hashBits:10, ordered:true) }); var result = pipe.Fit(dataView).Transform(dataView); ValidateMetadata(result); @@ -111,11 +112,11 @@ public void TestOldSavingAndLoading() { var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; var dataView = ComponentCreation.CreateDataView(Env, data); - var pipe = new HashEstimator(Env, new[]{ - new HashTransform.ColumnInfo("A", "HashA", hashBits:4, invertHash:-1), - new HashTransform.ColumnInfo("B", "HashB", hashBits:3, ordered:true), - new HashTransform.ColumnInfo("C", "HashC", seed:42), - new 
HashTransform.ColumnInfo("A", "HashD"), + var pipe = new HashConverter(Env, new[]{ + new HashConverterTransformer.ColumnInfo("A", "HashA", hashBits:4, invertHash:-1), + new HashConverterTransformer.ColumnInfo("B", "HashB", hashBits:3, ordered:true), + new HashConverterTransformer.ColumnInfo("C", "HashC", seed:42), + new HashConverterTransformer.ColumnInfo("A", "HashD"), }); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); From 2f887cea51a463df063c6f8336f3b3da663189f9 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Wed, 19 Sep 2018 11:59:23 -0700 Subject: [PATCH 5/8] more changes --- src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs | 3 +-- src/Microsoft.ML.Data/Transforms/HashTransform.cs | 9 +-------- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index e653443569..3a4a16864f 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -330,8 +330,7 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output int inc = 0; while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp)) stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc); - var est = new HashConverter(Host, stratificationColumn, origStratCol, 30); - output = est.Fit(input).Transform(input); + output = new HashConverter(Host, origStratCol, stratificationColumn, 30).Fit(input).Transform(input); } } diff --git a/src/Microsoft.ML.Data/Transforms/HashTransform.cs b/src/Microsoft.ML.Data/Transforms/HashTransform.cs index 9b3ada1494..e5f827ae6c 100644 --- a/src/Microsoft.ML.Data/Transforms/HashTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/HashTransform.cs @@ -1328,13 +1328,6 @@ public HashConverter(IHostEnvironment env, params HashConverterTransformer.Colum Contracts.CheckValue(env, 
nameof(env)); _host = env.Register(nameof(HashConverter)); _columns = columns.ToArray(); - foreach (var col in _columns) - { - if (col.InvertHash < -1) - throw _host.ExceptParam(nameof(columns), "Value too small, must be -1 or larger"); - if (col.InvertHash != 0 && col.HashBits >= 31) - throw _host.ExceptParam(nameof(columns), $"Cannot support invertHash for a {0} bit hash. 30 is the maximum possible.", col.HashBits); - } } public HashConverterTransformer Fit(IDataView input) => new HashConverterTransformer(_host, input, _columns); @@ -1354,7 +1347,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) metadata.Add(slotMeta); if (colInfo.InvertHash != 0) metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)); - result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.ItemType.IsVector ? SchemaShape.Column.VectorKind.Vector : SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, new SchemaShape(metadata.ToArray())); + result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.ItemType.IsVector ? 
SchemaShape.Column.VectorKind.Vector : SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, new SchemaShape(metadata)); } return new SchemaShape(result.Values); } From a8a912506c9f453ee21872f9ce805363375ee037 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Wed, 19 Sep 2018 13:05:16 -0700 Subject: [PATCH 6/8] Names --- .../Commands/CrossValidationCommand.cs | 2 +- .../Transforms/HashTransform.cs | 73 ++++++++++--------- .../CategoricalHashTransform.cs | 10 +-- .../Text/WordBagTransform.cs | 2 +- .../Text/WordHashBagTransform.cs | 10 +-- .../DataPipe/TestDataPipe.cs | 29 +++----- .../Transformers/HashTests.cs | 28 +++---- 7 files changed, 74 insertions(+), 80 deletions(-) diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index 3a4a16864f..f6670394ee 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -330,7 +330,7 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output int inc = 0; while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp)) stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc); - output = new HashConverter(Host, origStratCol, stratificationColumn, 30).Fit(input).Transform(input); + output = new HashEstimator(Host, origStratCol, stratificationColumn, 30).Fit(input).Transform(input); } } diff --git a/src/Microsoft.ML.Data/Transforms/HashTransform.cs b/src/Microsoft.ML.Data/Transforms/HashTransform.cs index e5f827ae6c..b11b423649 100644 --- a/src/Microsoft.ML.Data/Transforms/HashTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/HashTransform.cs @@ -15,17 +15,17 @@ using System.Runtime.CompilerServices; using System.Text; -[assembly: LoadableClass(HashConverterTransformer.Summary, typeof(IDataTransform), typeof(HashConverterTransformer), typeof(HashConverterTransformer.Arguments), typeof(SignatureDataTransform), +[assembly: 
LoadableClass(HashTransformer.Summary, typeof(IDataTransform), typeof(HashTransformer), typeof(HashTransformer.Arguments), typeof(SignatureDataTransform), "Hash Transform", "HashTransform", "Hash", DocName = "transform/HashTransform.md")] -[assembly: LoadableClass(HashConverterTransformer.Summary, typeof(IDataTransform), typeof(HashConverterTransformer), null, typeof(SignatureLoadDataTransform), - "Hash Transform", HashConverterTransformer.LoaderSignature)] +[assembly: LoadableClass(HashTransformer.Summary, typeof(IDataTransform), typeof(HashTransformer), null, typeof(SignatureLoadDataTransform), + "Hash Transform", HashTransformer.LoaderSignature)] -[assembly: LoadableClass(HashConverterTransformer.Summary, typeof(HashConverterTransformer), null, typeof(SignatureLoadModel), - "Hash Transform", HashConverterTransformer.LoaderSignature)] +[assembly: LoadableClass(HashTransformer.Summary, typeof(HashTransformer), null, typeof(SignatureLoadModel), + "Hash Transform", HashTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(IRowMapper), typeof(HashConverterTransformer), null, typeof(SignatureLoadRowMapper), - "Hash Transform", HashConverterTransformer.LoaderSignature)] +[assembly: LoadableClass(typeof(IRowMapper), typeof(HashTransformer), null, typeof(SignatureLoadRowMapper), + "Hash Transform", HashTransformer.LoaderSignature)] namespace Microsoft.ML.Transforms { @@ -36,7 +36,7 @@ namespace Microsoft.ML.Transforms /// it hashes each slot separately. /// It can hash either text values or key values. /// - public sealed class HashConverterTransformer : OneToOneTransformerBase + public sealed class HashTransformer : OneToOneTransformerBase { public sealed class Arguments { @@ -46,18 +46,18 @@ public sealed class Arguments [Argument(ArgumentType.AtMostOnce, HelpText = "Number of bits to hash into. 
Must be between 1 and 31, inclusive", ShortName = "bits", SortOrder = 2)] - public int HashBits = HashConverter.Defaults.HashBits; + public int HashBits = HashEstimator.Defaults.HashBits; [Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")] - public uint Seed = HashConverter.Defaults.Seed; + public uint Seed = HashEstimator.Defaults.Seed; [Argument(ArgumentType.AtMostOnce, HelpText = "Whether the position of each term should be included in the hash", ShortName = "ord")] - public bool Ordered = HashConverter.Defaults.Ordered; + public bool Ordered = HashEstimator.Defaults.Ordered; [Argument(ArgumentType.AtMostOnce, HelpText = "Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.", ShortName = "ih")] - public int InvertHash = HashConverter.Defaults.InvertHash; + public int InvertHash = HashEstimator.Defaults.InvertHash; } public sealed class Column : OneToOneColumn @@ -115,6 +115,7 @@ public bool TryUnparse(StringBuilder sb) return TryUnparseCore(sb, extra); } } + public sealed class ColumnInfo { public readonly string Input; @@ -134,10 +135,10 @@ public sealed class ColumnInfo /// Whether the position of each term should be included in the hash. /// Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit. 
public ColumnInfo(string input, string output, - int hashBits = HashConverter.Defaults.HashBits, - uint seed = HashConverter.Defaults.Seed, - bool ordered = HashConverter.Defaults.Ordered, - int invertHash = HashConverter.Defaults.InvertHash) + int hashBits = HashEstimator.Defaults.HashBits, + uint seed = HashEstimator.Defaults.Seed, + bool ordered = HashEstimator.Defaults.Ordered, + int invertHash = HashEstimator.Defaults.InvertHash) { if (invertHash < -1) throw Contracts.ExceptParam(nameof(invertHash), "Value too small, must be -1 or larger"); @@ -161,7 +162,7 @@ internal ColumnInfo(string input, string output, ModelLoadContext ctx) // uint: HashSeed // byte: Ordered HashBits = ctx.Reader.ReadInt32(); - Contracts.CheckDecode(HashConverter.NumBitsMin <= HashBits && HashBits < HashConverter.NumBitsLim); + Contracts.CheckDecode(HashEstimator.NumBitsMin <= HashBits && HashBits < HashEstimator.NumBitsLim); Seed = ctx.Reader.ReadUInt32(); Ordered = ctx.Reader.ReadBoolByte(); } @@ -173,7 +174,7 @@ internal void Save(ModelSaveContext ctx) // uint: HashSeed // byte: Ordered - Contracts.Assert(HashConverter.NumBitsMin <= HashBits && HashBits < HashConverter.NumBitsLim); + Contracts.Assert(HashEstimator.NumBitsMin <= HashBits && HashBits < HashEstimator.NumBitsLim); ctx.Writer.Write(HashBits); ctx.Writer.Write(Seed); @@ -205,8 +206,8 @@ private static VersionInfo GetVersionInfo() protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { var type = inputSchema.GetColumnType(srcCol); - if (!HashConverter.IsColumnTypeValid(type)) - throw Host.ExceptParam(nameof(inputSchema), HashConverter.ExpectedColumnType); + if (!HashEstimator.IsColumnTypeValid(type)) + throw Host.ExceptParam(nameof(inputSchema), HashEstimator.ExpectedColumnType); } private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) @@ -227,18 +228,18 @@ private ColumnType GetOutputType(ISchema inputSchema, ColumnInfo column) return new VectorType(itemType, 
srcType.VectorSize); } - public HashConverterTransformer(IHostEnvironment env, ColumnInfo[] columns): + public HashTransformer(IHostEnvironment env, ColumnInfo[] columns): base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) { _columns = columns.ToArray(); foreach(var column in _columns) { if (column.InvertHash != 0) - throw Host.ExceptParam(nameof(columns), $"Found colunm with {nameof(column.InvertHash)} set to non zero value, please use { nameof(HashConverter)} instead"); + throw Host.ExceptParam(nameof(columns), $"Found colunm with {nameof(column.InvertHash)} set to non zero value, please use { nameof(HashEstimator)} instead"); } } - internal HashConverterTransformer(IHostEnvironment env, IDataView input, ColumnInfo[] columns) : + internal HashTransformer(IHostEnvironment env, IDataView input, ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) { _columns = columns.ToArray(); @@ -317,7 +318,7 @@ internal Delegate GetGetterCore(IRow input, int iinfo, out Action disposer) protected override IRowMapper MakeRowMapper(ISchema schema) => new Mapper(this, schema); // Factory method for SignatureLoadModel. 
- private static HashConverterTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + private static HashTransformer Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); var host = env.Register(RegistrationName); @@ -325,10 +326,10 @@ private static HashConverterTransformer Create(IHostEnvironment env, ModelLoadCo host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new HashConverterTransformer(host, ctx); + return new HashTransformer(host, ctx); } - private HashConverterTransformer(IHost host, ModelLoadContext ctx) + private HashTransformer(IHost host, ModelLoadContext ctx) : base(host, ctx) { var columnsLength = ColumnPairs.Length; @@ -385,7 +386,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV item.Ordered ?? args.Ordered, item.InvertHash ?? args.InvertHash); }; - return new HashConverterTransformer(env, input, cols).MakeDataTransform(input); + return new HashTransformer(env, input, cols).MakeDataTransform(input); } #region Getters @@ -989,9 +990,9 @@ public ColInfo(string name, string source, ColumnType type) } private readonly ColumnType[] _types; - private readonly HashConverterTransformer _parent; + private readonly HashTransformer _parent; - public Mapper(HashConverterTransformer parent, ISchema inputSchema) + public Mapper(HashTransformer parent, ISchema inputSchema) : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { _parent = parent; @@ -1295,7 +1296,7 @@ public override void Process() } } - public sealed class HashConverter : IEstimator + public sealed class HashEstimator : IEstimator { public const int NumBitsMin = 1; public const int NumBitsLim = 32; @@ -1309,7 +1310,7 @@ public static class Defaults } private readonly IHost _host; - private readonly HashConverterTransformer.ColumnInfo[] _columns; + private readonly HashTransformer.ColumnInfo[] _columns; public static bool IsColumnTypeValid(ColumnType type) { @@ -1317,20 
+1318,20 @@ public static bool IsColumnTypeValid(ColumnType type) } internal const string ExpectedColumnType = "Expected Text, Key, Single or Double item type"; - public HashConverter(IHostEnvironment env, string name, string source = null, + public HashEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, int hashBits = Defaults.HashBits, int invertHash = Defaults.InvertHash) - : this(env, new HashConverterTransformer.ColumnInfo(name, source ?? name, hashBits: hashBits, invertHash: invertHash)) + : this(env, new HashTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, hashBits: hashBits, invertHash: invertHash)) { } - public HashConverter(IHostEnvironment env, params HashConverterTransformer.ColumnInfo[] columns) + public HashEstimator(IHostEnvironment env, params HashTransformer.ColumnInfo[] columns) { Contracts.CheckValue(env, nameof(env)); - _host = env.Register(nameof(HashConverter)); + _host = env.Register(nameof(HashEstimator)); _columns = columns.ToArray(); } - public HashConverterTransformer Fit(IDataView input) => new HashConverterTransformer(_host, input, _columns); + public HashTransformer Fit(IDataView input) => new HashTransformer(_host, input, _columns); public SchemaShape GetOutputSchema(SchemaShape inputSchema) { diff --git a/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs b/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs index 7f1a20bffc..5808a5e3e0 100644 --- a/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs +++ b/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs @@ -92,7 +92,7 @@ private static class Defaults } /// - /// This class is a merger of and + /// This class is a merger of and /// with join option removed /// public sealed class Arguments : TransformInputBase @@ -170,13 +170,13 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV throw h.ExceptUserArg(nameof(args.HashBits), "Number of bits must be between 1 and {0}", NumBitsLim - 1); // 
creating the Hash function - var hashArgs = new HashConverterTransformer.Arguments + var hashArgs = new HashTransformer.Arguments { HashBits = args.HashBits, Seed = args.Seed, Ordered = args.Ordered, InvertHash = args.InvertHash, - Column = new HashConverterTransformer.Column[args.Column.Length] + Column = new HashTransformer.Column[args.Column.Length] }; for (int i = 0; i < args.Column.Length; i++) { @@ -185,7 +185,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV throw h.ExceptUserArg(nameof(Column.Name)); h.Assert(!string.IsNullOrWhiteSpace(column.Name)); h.Assert(!string.IsNullOrWhiteSpace(column.Source)); - hashArgs.Column[i] = new HashConverterTransformer.Column + hashArgs.Column[i] = new HashTransformer.Column { HashBits = column.HashBits, Seed = column.Seed, @@ -199,7 +199,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV return CreateTransformCore( args.OutputKind, args.Column, args.Column.Select(col => col.OutputKind).ToList(), - HashConverterTransformer.Create(h, hashArgs, input), + HashTransformer.Create(h, hashArgs, input), h, args); } diff --git a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs index bab50fb980..78437dfb1b 100644 --- a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs @@ -475,7 +475,7 @@ public interface INgramExtractorFactory { /// /// Whether the extractor transform created by this factory uses the hashing trick - /// (by using or , for example). + /// (by using or , for example). 
/// bool UseHashingTrick { get; } diff --git a/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs index 520e9d710d..f99dbcb381 100644 --- a/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs @@ -267,7 +267,7 @@ public bool TryUnparse(StringBuilder sb) } /// - /// This class is a merger of and + /// This class is a merger of and /// , with the ordered option, /// the rehashUnigrams option and the allLength option removed. /// @@ -341,7 +341,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV List termCols = null; if (termLoaderArgs != null) termCols = new List(); - var hashColumns = new List(); + var hashColumns = new List(); var ngramHashColumns = new NgramHashTransform.Column[args.Column.Length]; var colCount = args.Column.Length; @@ -372,7 +372,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV } hashColumns.Add( - new HashConverterTransformer.Column + new HashTransformer.Column { Name = tmpName, Source = termLoaderArgs == null ? 
column.Source[isrc] : tmpName, @@ -436,7 +436,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV // Args for the Hash function with multiple columns var hashArgs = - new HashConverterTransformer.Arguments + new HashTransformer.Arguments { HashBits = 31, Seed = args.Seed, @@ -445,7 +445,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV InvertHash = args.InvertHash }; - view = HashConverterTransformer.Create(h, hashArgs, view); + view = HashTransformer.Create(h, hashArgs, view); // creating the NgramHash function var ngramHashArgs = diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index 99739d1fd1..79b22b0e91 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -2,19 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Float = System.Single; - -using System; -using System.Collections.Generic; -using System.IO; -using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.TextAnalytics; -using Xunit; using Microsoft.ML.Transforms; +using System; +using Xunit; +using Float = System.Single; namespace Microsoft.ML.Runtime.RunTests { @@ -83,14 +76,14 @@ private void TestHashTransformHelper(T[] data, uint[] results, NumberType typ builder.AddColumn("F1", type, data); var srcView = builder.GetDataView(); - var col = new HashConverterTransformer.Column(); + var col = new HashTransformer.Column(); col.Name = "F1"; col.HashBits = 5; col.Seed = 42; - var args = new HashConverterTransformer.Arguments(); - args.Column = new HashConverterTransformer.Column[] { col }; + var args = new HashTransformer.Arguments(); + args.Column = new HashTransformer.Column[] { col }; - var hashTransform = HashConverterTransformer.Create(Env, args, srcView); + var hashTransform = HashTransformer.Create(Env, args, srcView); using (var cursor = hashTransform.GetRowCursor(c => true)) { var resultGetter = cursor.GetGetter(1); @@ -121,14 +114,14 @@ private void TestHashTransformVectorHelper(VBuffer data, uint[][] results, private void TestHashTransformVectorHelper(ArrayDataViewBuilder builder, uint[][] results) { var srcView = builder.GetDataView(); - var col = new HashConverterTransformer.Column(); + var col = new HashTransformer.Column(); col.Source = "F1V"; col.HashBits = 5; col.Seed = 42; - var args = new HashConverterTransformer.Arguments(); - args.Column = new HashConverterTransformer.Column[] { col }; + var args = new HashTransformer.Arguments(); + args.Column = new HashTransformer.Column[] { col }; - var hashTransform = HashConverterTransformer.Create(Env, args, srcView); + var hashTransform = HashTransformer.Create(Env, args, srcView); using (var 
cursor = hashTransform.GetRowCursor(c => true)) { var resultGetter = cursor.GetGetter>(1); diff --git a/test/Microsoft.ML.Tests/Transformers/HashTests.cs b/test/Microsoft.ML.Tests/Transformers/HashTests.cs index e97c29d890..08108c6c1f 100644 --- a/test/Microsoft.ML.Tests/Transformers/HashTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/HashTests.cs @@ -44,11 +44,11 @@ public void HashWorkout() var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; var dataView = ComponentCreation.CreateDataView(Env, data); - var pipe = new HashConverter(Env, new[]{ - new HashConverterTransformer.ColumnInfo("A", "HashA", hashBits:4, invertHash:-1), - new HashConverterTransformer.ColumnInfo("B", "HashB", hashBits:3, ordered:true), - new HashConverterTransformer.ColumnInfo("C", "HashC", seed:42), - new HashConverterTransformer.ColumnInfo("A", "HashD"), + var pipe = new HashEstimator(Env, new[]{ + new HashTransformer.ColumnInfo("A", "HashA", hashBits:4, invertHash:-1), + new HashTransformer.ColumnInfo("B", "HashB", hashBits:3, ordered:true), + new HashTransformer.ColumnInfo("C", "HashC", seed:42), + new HashTransformer.ColumnInfo("A", "HashD"), }); TestEstimatorCore(pipe, dataView); @@ -66,10 +66,10 @@ public void TestMetadata() var dataView = ComponentCreation.CreateDataView(Env, data); - var pipe = new HashConverter(Env, new[] { - new HashConverterTransformer.ColumnInfo("A", "HashA", invertHash:1, hashBits:10), - new HashConverterTransformer.ColumnInfo("A", "HashAUnlim", invertHash:-1, hashBits:10), - new HashConverterTransformer.ColumnInfo("A", "HashAUnlimOrdered", invertHash:-1, hashBits:10, ordered:true) + var pipe = new HashEstimator(Env, new[] { + new HashTransformer.ColumnInfo("A", "HashA", invertHash:1, hashBits:10), + new HashTransformer.ColumnInfo("A", "HashAUnlim", invertHash:-1, hashBits:10), + new HashTransformer.ColumnInfo("A", "HashAUnlimOrdered", invertHash:-1, hashBits:10, ordered:true) }); var result = 
pipe.Fit(dataView).Transform(dataView); ValidateMetadata(result); @@ -112,11 +112,11 @@ public void TestOldSavingAndLoading() { var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; var dataView = ComponentCreation.CreateDataView(Env, data); - var pipe = new HashConverter(Env, new[]{ - new HashConverterTransformer.ColumnInfo("A", "HashA", hashBits:4, invertHash:-1), - new HashConverterTransformer.ColumnInfo("B", "HashB", hashBits:3, ordered:true), - new HashConverterTransformer.ColumnInfo("C", "HashC", seed:42), - new HashConverterTransformer.ColumnInfo("A", "HashD"), + var pipe = new HashEstimator(Env, new[]{ + new HashTransformer.ColumnInfo("A", "HashA", hashBits:4, invertHash:-1), + new HashTransformer.ColumnInfo("B", "HashB", hashBits:3, ordered:true), + new HashTransformer.ColumnInfo("C", "HashC", seed:42), + new HashTransformer.ColumnInfo("A", "HashD"), }); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); From 4262aea76cfe0bda2c05ce4989d8affbc3594f37 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Wed, 19 Sep 2018 13:10:55 -0700 Subject: [PATCH 7/8] unblock tests --- src/Microsoft.ML.Data/Transforms/HashTransform.cs | 4 ++-- test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/HashTransform.cs b/src/Microsoft.ML.Data/Transforms/HashTransform.cs index b11b423649..40cda02b90 100644 --- a/src/Microsoft.ML.Data/Transforms/HashTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/HashTransform.cs @@ -228,11 +228,11 @@ private ColumnType GetOutputType(ISchema inputSchema, ColumnInfo column) return new VectorType(itemType, srcType.VectorSize); } - public HashTransformer(IHostEnvironment env, ColumnInfo[] columns): + public HashTransformer(IHostEnvironment env, ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), 
GetColumnPairs(columns)) { _columns = columns.ToArray(); - foreach(var column in _columns) + foreach (var column in _columns) { if (column.InvertHash != 0) throw Host.ExceptParam(nameof(columns), $"Found colunm with {nameof(column.InvertHash)} set to non zero value, please use { nameof(HashEstimator)} instead"); diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index 79b22b0e91..e8bc11143e 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -115,7 +115,7 @@ private void TestHashTransformVectorHelper(ArrayDataViewBuilder builder, uint[][ { var srcView = builder.GetDataView(); var col = new HashTransformer.Column(); - col.Source = "F1V"; + col.Name = "F1V"; col.HashBits = 5; col.Seed = 42; var args = new HashTransformer.Arguments(); From 47db03c14270046307aa50382eb831f92754c732 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Wed, 19 Sep 2018 13:27:03 -0700 Subject: [PATCH 8/8] add comments --- .../Transforms/HashTransform.cs | 26 +++++++++++++++---- .../Transforms/TermEstimator.cs | 8 +++--- .../Transformers/HashTests.cs | 2 ++ 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/HashTransform.cs b/src/Microsoft.ML.Data/Transforms/HashTransform.cs index 40cda02b90..4e34aa2bca 100644 --- a/src/Microsoft.ML.Data/Transforms/HashTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/HashTransform.cs @@ -228,6 +228,11 @@ private ColumnType GetOutputType(ISchema inputSchema, ColumnInfo column) return new VectorType(itemType, srcType.VectorSize); } + /// + /// Constructor for case where you don't need to 'train' transform on data, e.g. InvertHash for all columns set to zero. + /// + /// Host Environment. + /// Description of dataset columns and how to process them. 
public HashTransformer(IHostEnvironment env, ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) { @@ -303,7 +308,7 @@ internal HashTransformer(IHostEnvironment env, IDataView input, ColumnInfo[] col } } - internal Delegate GetGetterCore(IRow input, int iinfo, out Action disposer) + private Delegate GetGetterCore(IRow input, int iinfo, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _columns.Length); @@ -1296,6 +1301,9 @@ public override void Process() } } + /// + /// Estimator for + /// public sealed class HashEstimator : IEstimator { public const int NumBitsMin = 1; @@ -1312,18 +1320,26 @@ public static class Defaults private readonly IHost _host; private readonly HashTransformer.ColumnInfo[] _columns; - public static bool IsColumnTypeValid(ColumnType type) - { - return (type.ItemType.IsText || type.ItemType.IsKey || type.ItemType == NumberType.R4 || type.ItemType == NumberType.R8); - } + public static bool IsColumnTypeValid(ColumnType type) => (type.ItemType.IsText || type.ItemType.IsKey || type.ItemType == NumberType.R4 || type.ItemType == NumberType.R8); + internal const string ExpectedColumnType = "Expected Text, Key, Single or Double item type"; + /// + /// Convenience constructor for simple one column case + /// + /// Host Environment. + /// Name of the output column. + /// Name of the column to be transformed. If this is null '' will be used. + /// Number of bits to hash into. Must be between 1 and 31, inclusive. + /// Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit. public HashEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, int hashBits = Defaults.HashBits, int invertHash = Defaults.InvertHash) : this(env, new HashTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, hashBits: hashBits, invertHash: invertHash)) { } + /// Host Environment. 
+ /// Description of dataset columns and how to process them. public HashEstimator(IHostEnvironment env, params HashTransformer.ColumnInfo[] columns) { Contracts.CheckValue(env, nameof(env)); diff --git a/src/Microsoft.ML.Data/Transforms/TermEstimator.cs b/src/Microsoft.ML.Data/Transforms/TermEstimator.cs index 446140573c..9f8135252e 100644 --- a/src/Microsoft.ML.Data/Transforms/TermEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/TermEstimator.cs @@ -25,13 +25,13 @@ public static class Defaults /// Convenience constructor for public facing API. /// /// Host Environment. - /// Name of the output column. - /// Name of the column to be transformed. If this is null '' will be used. + /// Name of the output column. + /// Name of the column to be transformed. If this is null '' will be used. /// Maximum number of terms to keep per column when auto-training. /// How items should be ordered when vectorized. By default, they will be in the order encountered. /// If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a'). - public TermEstimator(IHostEnvironment env, string name, string source = null, int maxNumTerms = Defaults.MaxNumTerms, TermTransform.SortOrder sort = Defaults.Sort) : - this(env, new TermTransform.ColumnInfo(name, source ?? name, maxNumTerms, sort)) + public TermEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, int maxNumTerms = Defaults.MaxNumTerms, TermTransform.SortOrder sort = Defaults.Sort) : + this(env, new TermTransform.ColumnInfo(inputColumn, outputColumn ?? 
inputColumn, maxNumTerms, sort)) { } diff --git a/test/Microsoft.ML.Tests/Transformers/HashTests.cs b/test/Microsoft.ML.Tests/Transformers/HashTests.cs index 08108c6c1f..ac2393b991 100644 --- a/test/Microsoft.ML.Tests/Transformers/HashTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/HashTests.cs @@ -87,6 +87,7 @@ private void ValidateMetadata(IDataView result) Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.KeyValues }); result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, HashA, ref keys); Assert.True(keys.Length == 1024); + //REVIEW: This is weird. I set invertHash to 1, so I expected only one value in the key values, but I got two. Assert.Equal(keys.Items().Select(x => x.Value.ToString()), new string[2] {"2.5", "3.5" }); types = result.Schema.GetMetadataTypes(HashAUnlim); @@ -101,6 +102,7 @@ private void ValidateMetadata(IDataView result) Assert.True(keys.Length == 1024); Assert.Equal(keys.Items().Select(x => x.Value.ToString()), new string[2] { "2.5", "3.5" }); } + [Fact] public void TestCommandLine() {