diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index b0f2e12ef1..f6670394ee 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -13,6 +13,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Transforms; [assembly: LoadableClass(typeof(CrossValidationCommand), typeof(CrossValidationCommand.Arguments), typeof(SignatureCommand), "Cross Validation", CrossValidationCommand.LoadName)] @@ -329,10 +330,7 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output int inc = 0; while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp)) stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc); - var hashargs = new HashTransform.Arguments(); - hashargs.Column = new[] { new HashTransform.Column { Source = origStratCol, Name = stratificationColumn } }; - hashargs.HashBits = 30; - output = new HashTransform(Host, hashargs, input); + output = new HashEstimator(Host, origStratCol, stratificationColumn, 30).Fit(input).Transform(input); } } diff --git a/src/Microsoft.ML.Data/Transforms/HashTransform.cs b/src/Microsoft.ML.Data/Transforms/HashTransform.cs index 8eec4e1581..a589121be6 100644 --- a/src/Microsoft.ML.Data/Transforms/HashTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/HashTransform.cs @@ -2,24 +2,32 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.CompilerServices; -using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Transforms; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; -[assembly: LoadableClass(HashTransform.Summary, typeof(HashTransform), typeof(HashTransform.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(HashTransformer.Summary, typeof(IDataTransform), typeof(HashTransformer), typeof(HashTransformer.Arguments), typeof(SignatureDataTransform), "Hash Transform", "HashTransform", "Hash", DocName = "transform/HashTransform.md")] -[assembly: LoadableClass(HashTransform.Summary, typeof(HashTransform), null, typeof(SignatureLoadDataTransform), - "Hash Transform", HashTransform.LoaderSignature)] +[assembly: LoadableClass(HashTransformer.Summary, typeof(IDataTransform), typeof(HashTransformer), null, typeof(SignatureLoadDataTransform), + "Hash Transform", HashTransformer.LoaderSignature)] + +[assembly: LoadableClass(HashTransformer.Summary, typeof(HashTransformer), null, typeof(SignatureLoadModel), + "Hash Transform", HashTransformer.LoaderSignature)] -namespace Microsoft.ML.Runtime.Data +[assembly: LoadableClass(typeof(IRowMapper), typeof(HashTransformer), null, typeof(SignatureLoadRowMapper), + "Hash Transform", HashTransformer.LoaderSignature)] + +namespace Microsoft.ML.Transforms { using Conditional = System.Diagnostics.ConditionalAttribute; @@ -28,19 +36,8 @@ namespace Microsoft.ML.Runtime.Data /// it hashes each slot separately. /// It can hash either text values or key values. 
/// - public sealed class HashTransform : OneToOneTransformBase, ITransformTemplate + public sealed class HashTransformer : OneToOneTransformerBase { - public const int NumBitsMin = 1; - public const int NumBitsLim = 32; - - private static class Defaults - { - public const int HashBits = NumBitsLim - 1; - public const uint Seed = 314489979; - public const bool Ordered = false; - public const int InvertHash = 0; - } - public sealed class Arguments { [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", @@ -49,18 +46,18 @@ public sealed class Arguments [Argument(ArgumentType.AtMostOnce, HelpText = "Number of bits to hash into. Must be between 1 and 31, inclusive", ShortName = "bits", SortOrder = 2)] - public int HashBits = Defaults.HashBits; + public int HashBits = HashEstimator.Defaults.HashBits; [Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")] - public uint Seed = Defaults.Seed; + public uint Seed = HashEstimator.Defaults.Seed; [Argument(ArgumentType.AtMostOnce, HelpText = "Whether the position of each term should be included in the hash", ShortName = "ord")] - public bool Ordered = Defaults.Ordered; + public bool Ordered = HashEstimator.Defaults.Ordered; [Argument(ArgumentType.AtMostOnce, HelpText = "Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.", ShortName = "ih")] - public int InvertHash = Defaults.InvertHash; + public int InvertHash = HashEstimator.Defaults.InvertHash; } public sealed class Column : OneToOneColumn @@ -95,8 +92,7 @@ protected override bool TryParse(string str) // We accept N:B:S where N is the new column name, B is the number of bits, // and S is source column names. 
- string extra; - if (!base.TryParse(str, out extra)) + if (!base.TryParse(str, out string extra)) return false; if (extra == null) return true; @@ -120,57 +116,72 @@ public bool TryUnparse(StringBuilder sb) } } - private sealed class ColInfoEx + public sealed class ColumnInfo { + public readonly string Input; + public readonly string Output; public readonly int HashBits; - public readonly uint HashSeed; + public readonly uint Seed; public readonly bool Ordered; + public readonly int InvertHash; - public ColInfoEx(Arguments args, Column col) + /// <summary> + /// Describes how the transformer handles one column pair. + /// </summary> + /// <param name="input">Name of input column.</param> + /// <param name="output">Name of output column.</param> + /// <param name="hashBits">Number of bits to hash into. Must be between 1 and 31, inclusive.</param> + /// <param name="seed">Hashing seed.</param> + /// <param name="ordered">Whether the position of each term should be included in the hash.</param> + /// <param name="invertHash">Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.</param> + public ColumnInfo(string input, string output, + int hashBits = HashEstimator.Defaults.HashBits, + uint seed = HashEstimator.Defaults.Seed, + bool ordered = HashEstimator.Defaults.Ordered, + int invertHash = HashEstimator.Defaults.InvertHash) { - HashBits = col.HashBits ?? args.HashBits; - if (HashBits < NumBitsMin || HashBits >= NumBitsLim) - throw Contracts.ExceptUserArg(nameof(args.HashBits), "Should be between {0} and {1} inclusive", NumBitsMin, NumBitsLim - 1); - HashSeed = col.Seed ?? args.Seed; - Ordered = col.Ordered ?? args.Ordered; + if (invertHash < -1) + throw Contracts.ExceptParam(nameof(invertHash), "Value too small, must be -1 or larger"); + if (invertHash != 0 && hashBits >= 31) + throw Contracts.ExceptParam(nameof(hashBits), "Cannot support invertHash for a {0} bit hash. 
30 is the maximum possible.", hashBits); + + Input = input; + Output = output; + HashBits = hashBits; + Seed = seed; + Ordered = ordered; + InvertHash = invertHash; } - public ColInfoEx(ModelLoadContext ctx) + internal ColumnInfo(string input, string output, ModelLoadContext ctx) { + Input = input; + Output = output; // *** Binary format *** // int: HashBits // uint: HashSeed // byte: Ordered - HashBits = ctx.Reader.ReadInt32(); - Contracts.CheckDecode(NumBitsMin <= HashBits && HashBits < NumBitsLim); - - HashSeed = ctx.Reader.ReadUInt32(); + Contracts.CheckDecode(HashEstimator.NumBitsMin <= HashBits && HashBits < HashEstimator.NumBitsLim); + Seed = ctx.Reader.ReadUInt32(); Ordered = ctx.Reader.ReadBoolByte(); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { // *** Binary format *** // int: HashBits // uint: HashSeed // byte: Ordered - Contracts.Assert(NumBitsMin <= HashBits && HashBits < NumBitsLim); + Contracts.Assert(HashEstimator.NumBitsMin <= HashBits && HashBits < HashEstimator.NumBitsLim); ctx.Writer.Write(HashBits); - ctx.Writer.Write(HashSeed); + ctx.Writer.Write(Seed); ctx.Writer.WriteBoolByte(Ordered); } } - private static string TestType(ColumnType type) - { - if (type.ItemType.IsText || type.ItemType.IsKey || type.ItemType == NumberType.R4 || type.ItemType == NumberType.R8) - return null; - return "Expected Text, Key, Single or Double item type"; - } - private const string RegistrationName = "Hash"; internal const string Summary = "Converts column values into hashes. This transform accepts text and keys as inputs. 
It works on single- and vector-valued columns, " @@ -188,109 +199,81 @@ private static VersionInfo GetVersionInfo() loaderSignature: LoaderSignature); } - private readonly ColInfoEx[] _exes; - private readonly ColumnType[] _types; - + private readonly ColumnInfo[] _columns; private readonly VBuffer>[] _keyValues; private readonly ColumnType[] _kvTypes; - public static HashTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { - Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); - ctx.CheckAtModel(GetVersionInfo()); - return h.Apply("Loading Model", ch => new HashTransform(h, ctx, input)); + var type = inputSchema.GetColumnType(srcCol); + if (!HashEstimator.IsColumnTypeValid(type)) + throw Host.ExceptParam(nameof(inputSchema), HashEstimator.ExpectedColumnType); } - private HashTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, TestType) + private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) { - Host.AssertValue(ctx); - - // *** Binary format *** - // - // Exes - - Host.AssertNonEmpty(Infos); - _exes = new ColInfoEx[Infos.Length]; - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) - _exes[iinfo] = new ColInfoEx(ctx); - - _types = InitColumnTypes(); - - TextModelHelper.LoadAll(Host, ctx, Infos.Length, out _keyValues, out _kvTypes); - SetMetadata(); + Contracts.CheckNonEmpty(columns, nameof(columns)); + return columns.Select(x => (x.Input, x.Output)).ToArray(); } - public override void Save(ModelSaveContext ctx) + private ColumnType GetOutputType(ISchema inputSchema, ColumnInfo column) { - Host.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(); - ctx.SetVersionInfo(GetVersionInfo()); - - // - // - // Exes - - SaveBase(ctx); - Host.Assert(_exes.Length == Infos.Length); - for (int iinfo = 0; 
iinfo < Infos.Length; iinfo++) - _exes[iinfo].Save(ctx); - - TextModelHelper.SaveAll(Host, ctx, Infos.Length, _keyValues); + var keyCount = column.HashBits < 31 ? 1 << column.HashBits : 0; + inputSchema.TryGetColumnIndex(column.Input, out int srcCol); + var itemType = new KeyType(DataKind.U4, 0, keyCount, keyCount > 0); + var srcType = inputSchema.GetColumnType(srcCol); + if (!srcType.IsVector) + return itemType; + else + return new VectorType(itemType, srcType.VectorSize); } /// - /// Convenience constructor for public facing API. + /// Constructor for case where you don't need to 'train' transform on data, e.g. InvertHash for all columns set to zero. /// /// Host Environment. - /// Input . This is the output from previous transform or loader. - /// Name of the output column. - /// Name of the column to be transformed. If this is null '' will be used. - /// Number of bits to hash into. Must be between 1 and 31, inclusive. - /// Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit. - public HashTransform(IHostEnvironment env, - IDataView input, - string name, - string source = null, - int hashBits = Defaults.HashBits, - int invertHash = Defaults.InvertHash) - : this(env, new Arguments() { - Column = new[] { new Column() { Source = source ?? name, Name = name } }, - HashBits = hashBits, InvertHash = invertHash }, input) + /// Description of dataset columns and how to process them. 
+ public HashTransformer(IHostEnvironment env, ColumnInfo[] columns) : + base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) { + _columns = columns.ToArray(); + foreach (var column in _columns) + { + if (column.InvertHash != 0) + throw Host.ExceptParam(nameof(columns), $"Found column with {nameof(column.InvertHash)} set to non zero value, please use {nameof(HashEstimator)} instead"); + } } - public HashTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(Contracts.CheckRef(env, nameof(env)), RegistrationName, env.CheckRef(args, nameof(args)).Column, - input, TestType) + internal HashTransformer(IHostEnvironment env, IDataView input, ColumnInfo[] columns) : + base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns)) { - if (args.HashBits < NumBitsMin || args.HashBits >= NumBitsLim) - throw Host.ExceptUserArg(nameof(args.HashBits), "hashBits should be between {0} and {1} inclusive", NumBitsMin, NumBitsLim - 1); - - _exes = new ColInfoEx[Infos.Length]; + _columns = columns.ToArray(); + var types = new ColumnType[_columns.Length]; List<int> invertIinfos = null; List<int> invertHashMaxCounts = null; - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) + HashSet<int> sourceColumnsForInvertHash = new HashSet<int>(); + for (int i = 0; i < _columns.Length; i++) { - _exes[iinfo] = new ColInfoEx(args, args.Column[iinfo]); - int invertHashMaxCount = GetAndVerifyInvertHashMaxCount(args, args.Column[iinfo], _exes[iinfo]); + if (!input.Schema.TryGetColumnIndex(ColumnPairs[i].input, out int srcCol)) + throw Host.ExceptSchemaMismatch(nameof(input), "input", ColumnPairs[i].input); + CheckInputColumn(input.Schema, i, srcCol); + + types[i] = GetOutputType(input.Schema, _columns[i]); + int invertHashMaxCount; + if (_columns[i].InvertHash == -1) + invertHashMaxCount = int.MaxValue; + else + invertHashMaxCount = _columns[i].InvertHash; if (invertHashMaxCount > 0) { - Utils.Add(ref invertIinfos, iinfo); + 
Utils.Add(ref invertIinfos, i); Utils.Add(ref invertHashMaxCounts, invertHashMaxCount); + sourceColumnsForInvertHash.Add(srcCol); } } - - _types = InitColumnTypes(); - - if (Utils.Size(invertIinfos) > 0) + if (Utils.Size(sourceColumnsForInvertHash) > 0) { - // Build the invert hashes for all columns for which it was requested. - var srcs = new HashSet(invertIinfos.Select(i => Infos[i].Source)); - using (IRowCursor srcCursor = input.GetRowCursor(srcs.Contains)) + using (IRowCursor srcCursor = input.GetRowCursor(sourceColumnsForInvertHash.Contains)) { using (var ch = Host.Start("Invert hash building")) { @@ -299,154 +282,146 @@ public HashTransform(IHostEnvironment env, Arguments args, IDataView input) for (int i = 0; i < helpers.Length; ++i) { int iinfo = invertIinfos[i]; - Host.Assert(_types[iinfo].ItemType.KeyCount > 0); - var dstGetter = GetGetterCore(ch, srcCursor, iinfo, out disposer); + Host.Assert(types[iinfo].ItemType.KeyCount > 0); + var dstGetter = GetGetterCore(srcCursor, iinfo, out disposer); Host.Assert(disposer == null); - var ex = _exes[iinfo]; + var ex = _columns[iinfo]; var maxCount = invertHashMaxCounts[i]; - helpers[i] = InvertHashHelper.Create(srcCursor, Infos[iinfo], ex, maxCount, dstGetter); + helpers[i] = InvertHashHelper.Create(srcCursor, ex, maxCount, dstGetter); } while (srcCursor.MoveNext()) { for (int i = 0; i < helpers.Length; ++i) helpers[i].Process(); } - _keyValues = new VBuffer>[_exes.Length]; - _kvTypes = new ColumnType[_exes.Length]; + _keyValues = new VBuffer>[_columns.Length]; + _kvTypes = new ColumnType[_columns.Length]; for (int i = 0; i < helpers.Length; ++i) { _keyValues[invertIinfos[i]] = helpers[i].GetKeyValuesMetadata(); - Host.Assert(_keyValues[invertIinfos[i]].Length == _types[invertIinfos[i]].ItemType.KeyCount); + Host.Assert(_keyValues[invertIinfos[i]].Length == types[invertIinfos[i]].ItemType.KeyCount); _kvTypes[invertIinfos[i]] = new VectorType(TextType.Instance, _keyValues[invertIinfos[i]].Length); } ch.Done(); } } 
} - SetMetadata(); } - /// - /// Re-apply constructor. - /// - private HashTransform(IHostEnvironment env, HashTransform transform, IDataView newSource) - : base(env, RegistrationName, transform, newSource, TestType) + private Delegate GetGetterCore(IRow input, int iinfo, out Action disposer) { - _exes = transform._exes; - _types = InitColumnTypes(); - _keyValues = transform._keyValues; - _kvTypes = transform._kvTypes; - SetMetadata(); + Host.AssertValue(input); + Host.Assert(0 <= iinfo && iinfo < _columns.Length); + disposer = null; + input.Schema.TryGetColumnIndex(_columns[iinfo].Input, out int srcCol); + var srcType = input.Schema.GetColumnType(srcCol); + if (!srcType.IsVector) + return ComposeGetterOne(input, iinfo, srcCol, srcType); + return ComposeGetterVec(input, iinfo, srcCol, srcType); } - public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource) - { - return new HashTransform(env, this, newSource); - } + protected override IRowMapper MakeRowMapper(ISchema schema) => new Mapper(this, schema); - private static int GetAndVerifyInvertHashMaxCount(Arguments args, Column col, ColInfoEx ex) + // Factory method for SignatureLoadModel. + private static HashTransformer Create(IHostEnvironment env, ModelLoadContext ctx) { - var invertHashMaxCount = col.InvertHash ?? args.InvertHash; - if (invertHashMaxCount != 0) - { - if (invertHashMaxCount == -1) - invertHashMaxCount = int.MaxValue; - Contracts.CheckUserArg(invertHashMaxCount > 0, nameof(args.InvertHash), "Value too small, must be -1 or larger"); - // If the bits is 31 or higher, we can't declare a KeyValues of the appropriate length, - // this requiring a VBuffer of length 1u << 31 which exceeds int.MaxValue. - if (ex.HashBits >= 31) - throw Contracts.ExceptUserArg(nameof(args.InvertHash), "Cannot support invertHash for a {0} bit hash. 
30 is the maximum possible.", ex.HashBits); - } - return invertHashMaxCount; - } + Contracts.CheckValue(env, nameof(env)); + var host = env.Register(RegistrationName); - private ColumnType[] InitColumnTypes() - { - var types = new ColumnType[Infos.Length]; - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) - { - var keyCount = _exes[iinfo].HashBits < 31 ? 1 << _exes[iinfo].HashBits : 0; - var itemType = new KeyType(DataKind.U4, 0, keyCount, keyCount > 0); - if (!Infos[iinfo].TypeSrc.IsVector) - types[iinfo] = itemType; - else - types[iinfo] = new VectorType(itemType, Infos[iinfo].TypeSrc.VectorSize); - } - return types; - } + host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); - protected override ColumnType GetColumnTypeCore(int iinfo) - { - Host.Check(0 <= iinfo & iinfo < Infos.Length); - return _types[iinfo]; + return new HashTransformer(host, ctx); } - private void SetMetadata() + private HashTransformer(IHost host, ModelLoadContext ctx) + : base(host, ctx) { - var md = Metadata; - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) - { - using (var bldr = md.BuildMetadata(iinfo, Source.Schema, Infos[iinfo].Source, - MetadataUtils.Kinds.SlotNames)) - { - if (_kvTypes != null && _kvTypes[iinfo] != null) - bldr.AddGetter>>(MetadataUtils.Kinds.KeyValues, _kvTypes[iinfo], GetTerms); - } - } - md.Seal(); + var columnsLength = ColumnPairs.Length; + _columns = new ColumnInfo[columnsLength]; + for (int i = 0; i < columnsLength; i++) + _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output, ctx); + TextModelHelper.LoadAll(Host, ctx, columnsLength, out _keyValues, out _kvTypes); } - private void GetTerms(int iinfo, ref VBuffer> dst) + public override void Save(ModelSaveContext ctx) { - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - Host.Assert(Utils.Size(_keyValues) == Infos.Length); - Host.Assert(_keyValues[iinfo].Length > 0); - _keyValues[iinfo].CopyTo(ref dst); + Host.CheckValue(ctx, nameof(ctx)); + + ctx.CheckAtModel(); + 
ctx.SetVersionInfo(GetVersionInfo()); + + SaveColumns(ctx); + + // + // + // + Host.Assert(_columns.Length == ColumnPairs.Length); + foreach (var col in _columns) + col.Save(ctx); + + TextModelHelper.SaveAll(Host, ctx, _columns.Length, _keyValues); } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + // Factory method for SignatureLoadDataTransform. + private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); + + // Factory method for SignatureLoadRowMapper. + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + + // Factory method for SignatureDataTransform. + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - disposer = null; + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); - if (!Infos[iinfo].TypeSrc.IsVector) - return ComposeGetterOne(input, iinfo); - return ComposeGetterVec(input, iinfo); + env.CheckValue(args.Column, nameof(args.Column)); + var cols = new ColumnInfo[args.Column.Length]; + for (int i = 0; i < cols.Length; i++) + { + var item = args.Column[i]; + var invertHash = item.InvertHash ?? args.InvertHash; + cols[i] = new ColumnInfo(item.Source ?? item.Name, + item.Name, + item.HashBits ?? args.HashBits, + item.Seed ?? args.Seed, + item.Ordered ?? args.Ordered, + 
invertHash); + } + return new HashTransformer(env, input, cols).MakeDataTransform(input); } - /// - /// Getter generator for single valued inputs - /// - private ValueGetter ComposeGetterOne(IRow input, int iinfo) + #region Getters + private ValueGetter ComposeGetterOne(IRow input, int iinfo, int srcCol, ColumnType srcType) { - var colType = Infos[iinfo].TypeSrc; - Host.Assert(colType.IsText || colType.IsKey || colType == NumberType.R4 || colType == NumberType.R8); + Host.Assert(srcType.IsText || srcType.IsKey || srcType == NumberType.R4 || srcType == NumberType.R8); - var mask = (1U << _exes[iinfo].HashBits) - 1; - uint seed = _exes[iinfo].HashSeed; + var mask = (1U << _columns[iinfo].HashBits) - 1; + uint seed = _columns[iinfo].Seed; // In case of single valued input column, hash in 0 for the slot index. - if (_exes[iinfo].Ordered) + if (_columns[iinfo].Ordered) seed = Hashing.MurmurRound(seed, 0); - switch (colType.RawKind) + switch (srcType.RawKind) { - case DataKind.Text: - return ComposeGetterOneCore(GetSrcGetter>(input, iinfo), seed, mask); - case DataKind.U1: - return ComposeGetterOneCore(GetSrcGetter(input, iinfo), seed, mask); - case DataKind.U2: - return ComposeGetterOneCore(GetSrcGetter(input, iinfo), seed, mask); - case DataKind.U4: - return ComposeGetterOneCore(GetSrcGetter(input, iinfo), seed, mask); - case DataKind.R4: - return ComposeGetterOneCore(GetSrcGetter(input, iinfo), seed, mask); - case DataKind.R8: - return ComposeGetterOneCore(GetSrcGetter(input, iinfo), seed, mask); - default: - Host.Assert(colType.RawKind == DataKind.U8); - return ComposeGetterOneCore(GetSrcGetter(input, iinfo), seed, mask); + case DataKind.Text: + return ComposeGetterOneCore(input.GetGetter>(srcCol), seed, mask); + case DataKind.U1: + return ComposeGetterOneCore(input.GetGetter(srcCol), seed, mask); + case DataKind.U2: + return ComposeGetterOneCore(input.GetGetter(srcCol), seed, mask); + case DataKind.U4: + return ComposeGetterOneCore(input.GetGetter(srcCol), 
seed, mask); + case DataKind.R4: + return ComposeGetterOneCore(input.GetGetter(srcCol), seed, mask); + case DataKind.R8: + return ComposeGetterOneCore(input.GetGetter(srcCol), seed, mask); + default: + Host.Assert(srcType.RawKind == DataKind.U8); + return ComposeGetterOneCore(input.GetGetter(srcCol), seed, mask); } } @@ -537,43 +512,42 @@ private ValueGetter ComposeGetterOneCore(ValueGetter getSrc, uint // them (either with their index or without) into dst. Additionally it fills in zero hashes in the rest of dst elements. private delegate void HashLoopWithZeroHash(int count, int[] indices, TSrc[] src, uint[] dst, int dstCount, uint seed, uint mask); - private ValueGetter> ComposeGetterVec(IRow input, int iinfo) + private ValueGetter> ComposeGetterVec(IRow input, int iinfo, int srcCol, ColumnType srcType) { - var colType = Infos[iinfo].TypeSrc; - Host.Assert(colType.IsVector); - Host.Assert(colType.ItemType.IsText || colType.ItemType.IsKey || colType.ItemType == NumberType.R4 || colType.ItemType == NumberType.R8); + Host.Assert(srcType.IsVector); + Host.Assert(srcType.ItemType.IsText || srcType.ItemType.IsKey || srcType.ItemType == NumberType.R4 || srcType.ItemType == NumberType.R8); - switch (colType.ItemType.RawKind) + switch (srcType.ItemType.RawKind) { - case DataKind.Text: - return ComposeGetterVecCore>(input, iinfo, HashUnord, HashDense, HashSparse); - case DataKind.U1: - return ComposeGetterVecCore(input, iinfo, HashUnord, HashDense, HashSparse); - case DataKind.U2: - return ComposeGetterVecCore(input, iinfo, HashUnord, HashDense, HashSparse); - case DataKind.U4: - return ComposeGetterVecCore(input, iinfo, HashUnord, HashDense, HashSparse); - case DataKind.R4: - return ComposeGetterVecCoreFloat(input, iinfo, HashSparseUnord, HashUnord, HashDense); - case DataKind.R8: - return ComposeGetterVecCoreFloat(input, iinfo, HashSparseUnord, HashUnord, HashDense); - default: - Host.Assert(colType.ItemType.RawKind == DataKind.U8); - return ComposeGetterVecCore(input, 
iinfo, HashUnord, HashDense, HashSparse); + case DataKind.Text: + return ComposeGetterVecCore>(input, iinfo, srcCol, srcType, HashUnord, HashDense, HashSparse); + case DataKind.U1: + return ComposeGetterVecCore(input, iinfo, srcCol, srcType, HashUnord, HashDense, HashSparse); + case DataKind.U2: + return ComposeGetterVecCore(input, iinfo, srcCol, srcType, HashUnord, HashDense, HashSparse); + case DataKind.U4: + return ComposeGetterVecCore(input, iinfo, srcCol, srcType, HashUnord, HashDense, HashSparse); + case DataKind.R4: + return ComposeGetterVecCoreFloat(input, iinfo, srcCol, srcType, HashSparseUnord, HashUnord, HashDense); + case DataKind.R8: + return ComposeGetterVecCoreFloat(input, iinfo, srcCol, srcType, HashSparseUnord, HashUnord, HashDense); + default: + Host.Assert(srcType.ItemType.RawKind == DataKind.U8); + return ComposeGetterVecCore(input, iinfo, srcCol, srcType, HashUnord, HashDense, HashSparse); } } - private ValueGetter> ComposeGetterVecCore(IRow input, int iinfo, + private ValueGetter> ComposeGetterVecCore(IRow input, int iinfo, int srcCol, ColumnType srcType, HashLoop hasherUnord, HashLoop hasherDense, HashLoop hasherSparse) { - Host.Assert(Infos[iinfo].TypeSrc.IsVector); - Host.Assert(Infos[iinfo].TypeSrc.ItemType.RawType == typeof(T)); + Host.Assert(srcType.IsVector); + Host.Assert(srcType.ItemType.RawType == typeof(T)); - var getSrc = GetSrcGetter>(input, iinfo); - var ex = _exes[iinfo]; + var getSrc = input.GetGetter>(srcCol); + var ex = _columns[iinfo]; var mask = (1U << ex.HashBits) - 1; - var seed = ex.HashSeed; - var len = Infos[iinfo].TypeSrc.VectorSize; + var seed = ex.Seed; + var len = srcType.VectorSize; var src = default(VBuffer); if (!ex.Ordered) @@ -612,20 +586,20 @@ private ValueGetter> ComposeGetterVecCore(IRow input, int iinfo }; } - private ValueGetter> ComposeGetterVecCoreFloat(IRow input, int iinfo, + private ValueGetter> ComposeGetterVecCoreFloat(IRow input, int iinfo, int srcCol, ColumnType srcType, HashLoopWithZeroHash 
hasherSparseUnord, HashLoop hasherDenseUnord, HashLoop hasherDenseOrdered) { - Host.Assert(Infos[iinfo].TypeSrc.IsVector); - Host.Assert(Infos[iinfo].TypeSrc.ItemType.RawType == typeof(T)); + Host.Assert(srcType.IsVector); + Host.Assert(srcType.ItemType.RawType == typeof(T)); - var getSrc = GetSrcGetter>(input, iinfo); - var ex = _exes[iinfo]; + var getSrc = input.GetGetter>(srcCol); + var ex = _columns[iinfo]; var mask = (1U << ex.HashBits) - 1; - var seed = ex.HashSeed; - var len = Infos[iinfo].TypeSrc.VectorSize; + var seed = ex.Seed; + var len = srcType.VectorSize; var src = default(VBuffer); T[] denseValues = null; - int expectedSrcLength = Infos[iinfo].TypeSrc.VectorSize; + int expectedSrcLength = srcType.VectorSize; HashLoop hasherDense = ex.Ordered ? hasherDenseOrdered : hasherDenseUnord; return @@ -668,6 +642,8 @@ private ValueGetter> ComposeGetterVecCoreFloat(IRow input, int }; } + #endregion + #region Core Hash functions, with and without index [MethodImpl(MethodImplOptions.AggressiveInlining)] private static uint HashCore(uint seed, ref ReadOnlyMemory value, uint mask) @@ -704,7 +680,7 @@ private static uint HashCore(uint seed, ref float value, int i, uint mask) if (float.IsNaN(value)) return 0; return (Hashing.MixHash(Hashing.MurmurRound(Hashing.MurmurRound(seed, (uint)i), - FloatUtils.GetBits(value == 0 ? 0: value))) & mask) + 1; + FloatUtils.GetBits(value == 0 ? 
0 : value))) & mask) + 1; } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -787,9 +763,9 @@ private static uint HashCore(uint seed, ulong value, int i, uint mask) hash = Hashing.MurmurRound(hash, hi); return (Hashing.MixHash(hash) & mask) + 1; } -#endregion Core Hash functions, with and without index + #endregion Core Hash functions, with and without index -#region Unordered Loop: ignore indices + #region Unordered Loop: ignore indices private static void HashUnord(int count, int[] indices, ReadOnlyMemory[] src, uint[] dst, uint seed, uint mask) { AssertValid(count, src, dst); @@ -847,7 +823,7 @@ private static void HashUnord(int count, int[] indices, double[] src, uint[] dst #endregion Unordered Loop: ignore indices -#region Dense Loop: ignore indices + #region Dense Loop: ignore indices private static void HashDense(int count, int[] indices, ReadOnlyMemory[] src, uint[] dst, uint seed, uint mask) { AssertValid(count, src, dst); @@ -902,9 +878,9 @@ private static void HashDense(int count, int[] indices, double[] src, uint[] dst for (int i = 0; i < count; i++) dst[i] = HashCore(seed, ref src[i], i, mask); } -#endregion Dense Loop: ignore indices + #endregion Dense Loop: ignore indices -#region Sparse Loop: use indices + #region Sparse Loop: use indices private static void HashSparse(int count, int[] indices, ReadOnlyMemory[] src, uint[] dst, uint seed, uint mask) { AssertValid(count, src, dst); @@ -992,7 +968,7 @@ private static void HashSparseUnord(int count, int[] indices, double[] src, uint } } -#endregion Sparse Loop: use indices + #endregion Sparse Loop: use indices [Conditional("DEBUG")] private static void AssertValid(int count, T[] src, uint[] dst) @@ -1002,27 +978,94 @@ private static void AssertValid(int count, T[] src, uint[] dst) Contracts.Assert(count <= Utils.Size(dst)); } - /// - /// This is a utility class to acquire and build the inverse hash to populate - /// KeyValues metadata. 
- /// + private sealed class Mapper : MapperBase + { + private sealed class ColInfo + { + public readonly string Name; + public readonly string Source; + public readonly ColumnType TypeSrc; + + public ColInfo(string name, string source, ColumnType type) + { + Name = name; + Source = source; + TypeSrc = type; + } + } + + private readonly ColumnType[] _types; + private readonly HashTransformer _parent; + + public Mapper(HashTransformer parent, ISchema inputSchema) + : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) + { + _parent = parent; + _types = new ColumnType[_parent._columns.Length]; + for (int i = 0; i < _types.Length; i++) + _types[i] = _parent.GetOutputType(inputSchema, _parent._columns[i]); + } + + public override RowMapperColumnInfo[] GetOutputColumns() + { + var result = new RowMapperColumnInfo[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colIndex); + var colMetaInfo = new ColumnMetadataInfo(_parent.ColumnPairs[i].output); + + foreach (var type in InputSchema.GetMetadataTypes(colIndex).Where(x => x.Key == MetadataUtils.Kinds.SlotNames)) + Utils.MarshalInvoke(AddMetaGetter, type.Value.RawType, colMetaInfo, InputSchema, type.Key, type.Value, colIndex); + if (_parent._kvTypes != null && _parent._kvTypes[i] != null) + AddMetaKeyValues(i, colMetaInfo); + result[i] = new RowMapperColumnInfo(_parent.ColumnPairs[i].output, _types[i], colMetaInfo); + } + return result; + } + private void AddMetaKeyValues(int i, ColumnMetadataInfo colMetaInfo) + { + MetadataUtils.MetadataGetter>> getter = (int col, ref VBuffer> dst) => + { + _parent._keyValues[i].CopyTo(ref dst); + }; + var info = new MetadataInfo>>(_parent._kvTypes[i], getter); + colMetaInfo.Add(MetadataUtils.Kinds.KeyValues, info); + } + + private int AddMetaGetter(ColumnMetadataInfo colMetaInfo, ISchema schema, string kind, ColumnType ct, int originalCol) + { + 
MetadataUtils.MetadataGetter getter = (int col, ref T dst) => + { + // We don't care about 'col': this getter is specialized for a column 'originalCol', + // and 'col' in this case is the 'metadata kind index', not the column index. + schema.GetMetadata(kind, originalCol, ref dst); + }; + var info = new MetadataInfo(ct, getter); + colMetaInfo.Add(kind, info); + return 0; + } + + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) => _parent.GetGetterCore(input, iinfo, out disposer); + } + private abstract class InvertHashHelper { protected readonly IRow Row; private readonly bool _includeSlot; - private readonly ColInfo _info; - private readonly ColInfoEx _ex; + private readonly ColumnInfo _ex; + private readonly ColumnType _srcType; + private readonly int _srcCol; - private InvertHashHelper(IRow row, ColInfo info, ColInfoEx ex) + private InvertHashHelper(IRow row, ColumnInfo ex) { Contracts.AssertValue(row); - Contracts.AssertValue(info); - Row = row; - _info = info; + row.Schema.TryGetColumnIndex(ex.Input, out int srcCol); + _srcCol = srcCol; + _srcType = row.Schema.GetColumnType(srcCol); _ex = ex; // If this is a vector and ordered, then we must include the slot as part of the representation. - _includeSlot = _info.TypeSrc.IsVector && _ex.Ordered; + _includeSlot = _srcType.IsVector && _ex.Ordered; } /// @@ -1031,18 +1074,18 @@ private InvertHashHelper(IRow row, ColInfo info, ColInfoEx ex) /// the row. /// /// The input source row, from which the hashed values can be fetched - /// The column info, describing the source /// The extra column info - /// The number of input hashed values to accumulate per output hash value + /// The number of input hashed valuPres to accumulate per output hash value /// A hash getter, built on top of . 
- public static InvertHashHelper Create(IRow row, ColInfo info, ColInfoEx ex, int invertHashMaxCount, Delegate dstGetter) + public static InvertHashHelper Create(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) { - ColumnType typeSrc = info.TypeSrc; + row.Schema.TryGetColumnIndex(ex.Input, out int srcCol); + ColumnType typeSrc = row.Schema.GetColumnType(srcCol); Type t = typeSrc.IsVector ? (ex.Ordered ? typeof(ImplVecOrdered<>) : typeof(ImplVec<>)) : typeof(ImplOne<>); t = t.MakeGenericType(typeSrc.ItemType.RawType); - var consTypes = new Type[] { typeof(IRow), typeof(OneToOneTransformBase.ColInfo), typeof(ColInfoEx), typeof(int), typeof(Delegate) }; + var consTypes = new Type[] { typeof(IRow), typeof(ColumnInfo), typeof(int), typeof(Delegate) }; var constructorInfo = t.GetConstructor(consTypes); - return (InvertHashHelper)constructorInfo.Invoke(new object[] { row, info, ex, invertHashMaxCount, dstGetter }); + return (InvertHashHelper)constructorInfo.Invoke(new object[] { row, ex, invertHashMaxCount, dstGetter }); } /// @@ -1095,7 +1138,7 @@ public int GetHashCode(KeyValuePair obj) private IEqualityComparer GetSimpleComparer() { - Contracts.Assert(_info.TypeSrc.ItemType.RawType == typeof(T)); + Contracts.Assert(_srcType.ItemType.RawType == typeof(T)); if (typeof(T) == typeof(ReadOnlyMemory)) { // We are hashing twice, once to assign to the slot, and then again, @@ -1103,7 +1146,7 @@ private IEqualityComparer GetSimpleComparer() // same seed used to assign to a slot, or otherwise this per-slot hash // would have a lot of collisions. We ensure that we have different // hash function by inverting the seed's bits. 
- var c = new TextEqualityComparer(~_ex.HashSeed); + var c = new TextEqualityComparer(~_ex.Seed); return c as IEqualityComparer; } // I assume I hope correctly that the default .NET hash function for uint @@ -1117,11 +1160,10 @@ private abstract class Impl : InvertHashHelper { protected readonly InvertHashCollector Collector; - protected Impl(IRow row, ColInfo info, ColInfoEx ex, int invertHashMaxCount) - : base(row, info, ex) + protected Impl(IRow row, ColumnInfo ex, int invertHashMaxCount) + : base(row, ex) { Contracts.AssertValue(row); - Contracts.AssertValue(info); Contracts.AssertValue(ex); Collector = new InvertHashCollector(1 << ex.HashBits, invertHashMaxCount, GetTextMap(), GetComparer()); @@ -1129,7 +1171,7 @@ protected Impl(IRow row, ColInfo info, ColInfoEx ex, int invertHashMaxCount) protected virtual ValueMapper GetTextMap() { - return InvertHashUtils.GetSimpleMapper(Row.Schema, _info.Source); + return InvertHashUtils.GetSimpleMapper(Row.Schema, _srcCol); } protected virtual IEqualityComparer GetComparer() @@ -1151,10 +1193,10 @@ private sealed class ImplOne : Impl private T _value; private uint _hash; - public ImplOne(IRow row, OneToOneTransformBase.ColInfo info, ColInfoEx ex, int invertHashMaxCount, Delegate dstGetter) - : base(row, info, ex, invertHashMaxCount) + public ImplOne(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) + : base(row, ex, invertHashMaxCount) { - _srcGetter = Row.GetGetter(_info.Source); + _srcGetter = Row.GetGetter(_srcCol); _dstGetter = dstGetter as ValueGetter; Contracts.AssertValue(_dstGetter); } @@ -1185,10 +1227,10 @@ private sealed class ImplVec : Impl private VBuffer _value; private VBuffer _hash; - public ImplVec(IRow row, OneToOneTransformBase.ColInfo info, ColInfoEx ex, int invertHashMaxCount, Delegate dstGetter) - : base(row, info, ex, invertHashMaxCount) + public ImplVec(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) + : base(row, ex, invertHashMaxCount) { - _srcGetter = 
Row.GetGetter>(_info.Source); + _srcGetter = Row.GetGetter>(_srcCol); _dstGetter = dstGetter as ValueGetter>; Contracts.AssertValue(_dstGetter); } @@ -1214,17 +1256,17 @@ private sealed class ImplVecOrdered : Impl> private VBuffer _value; private VBuffer _hash; - public ImplVecOrdered(IRow row, OneToOneTransformBase.ColInfo info, ColInfoEx ex, int invertHashMaxCount, Delegate dstGetter) - : base(row, info, ex, invertHashMaxCount) + public ImplVecOrdered(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) + : base(row, ex, invertHashMaxCount) { - _srcGetter = Row.GetGetter>(_info.Source); + _srcGetter = Row.GetGetter>(_srcCol); _dstGetter = dstGetter as ValueGetter>; Contracts.AssertValue(_dstGetter); } protected override ValueMapper, StringBuilder> GetTextMap() { - var simple = InvertHashUtils.GetSimpleMapper(Row.Schema, _info.Source); + var simple = InvertHashUtils.GetSimpleMapper(Row.Schema, _srcCol); return InvertHashUtils.GetPairMapper(simple); } @@ -1255,5 +1297,74 @@ public override void Process() } } } + + /// + /// Estimator for + /// + public sealed class HashEstimator : IEstimator + { + public const int NumBitsMin = 1; + public const int NumBitsLim = 32; + + public static class Defaults + { + public const int HashBits = NumBitsLim - 1; + public const uint Seed = 314489979; + public const bool Ordered = false; + public const int InvertHash = 0; + } + + private readonly IHost _host; + private readonly HashTransformer.ColumnInfo[] _columns; + + public static bool IsColumnTypeValid(ColumnType type) => (type.ItemType.IsText || type.ItemType.IsKey || type.ItemType == NumberType.R4 || type.ItemType == NumberType.R8); + + internal const string ExpectedColumnType = "Expected Text, Key, Single or Double item type"; + + /// + /// Convinence constructor for simple one column case + /// + /// Host Environment. + /// Name of the output column. + /// Name of the column to be transformed. If this is null '' will be used. 
+ /// Number of bits to hash into. Must be between 1 and 31, inclusive. + /// Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit. + public HashEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, + int hashBits = Defaults.HashBits, int invertHash = Defaults.InvertHash) + : this(env, new HashTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, hashBits: hashBits, invertHash: invertHash)) + { + } + + /// Host Environment. + /// Description of dataset columns and how to process them. + public HashEstimator(IHostEnvironment env, params HashTransformer.ColumnInfo[] columns) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(HashEstimator)); + _columns = columns.ToArray(); + } + + public HashTransformer Fit(IDataView input) => new HashTransformer(_host, input, _columns); + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var colInfo in _columns) + { + if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!IsColumnTypeValid(col.ItemType)) + throw _host.ExceptParam(nameof(inputSchema), ExpectedColumnType); + var metadata = new List(); + if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta)) + metadata.Add(slotMeta); + if (colInfo.InvertHash != 0) + metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)); + result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.ItemType.IsVector ? 
SchemaShape.Column.VectorKind.Vector : SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, new SchemaShape(metadata)); + } + return new SchemaShape(result.Values); + } + } } diff --git a/src/Microsoft.ML.Data/Transforms/TermEstimator.cs b/src/Microsoft.ML.Data/Transforms/TermEstimator.cs index 446140573c..9f8135252e 100644 --- a/src/Microsoft.ML.Data/Transforms/TermEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/TermEstimator.cs @@ -25,13 +25,13 @@ public static class Defaults /// Convenience constructor for public facing API. /// /// Host Environment. - /// Name of the output column. - /// Name of the column to be transformed. If this is null '' will be used. + /// Name of the output column. + /// Name of the column to be transformed. If this is null '' will be used. /// Maximum number of terms to keep per column when auto-training. /// How items should be ordered when vectorized. By default, they will be in the order encountered. /// If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a'). - public TermEstimator(IHostEnvironment env, string name, string source = null, int maxNumTerms = Defaults.MaxNumTerms, TermTransform.SortOrder sort = Defaults.Sort) : - this(env, new TermTransform.ColumnInfo(name, source ?? name, maxNumTerms, sort)) + public TermEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, int maxNumTerms = Defaults.MaxNumTerms, TermTransform.SortOrder sort = Defaults.Sort) : + this(env, new TermTransform.ColumnInfo(inputColumn, outputColumn ?? 
inputColumn, maxNumTerms, sort)) { } diff --git a/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs b/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs index 582586b9a0..5808a5e3e0 100644 --- a/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs +++ b/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs @@ -10,6 +10,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Transforms; [assembly: LoadableClass(CategoricalHashTransform.Summary, typeof(IDataTransform), typeof(CategoricalHashTransform), typeof(CategoricalHashTransform.Arguments), typeof(SignatureDataTransform), CategoricalHashTransform.UserName, "CategoricalHashTransform", "CatHashTransform", "CategoricalHash", "CatHash")] @@ -91,7 +92,7 @@ private static class Defaults } /// - /// This class is a merger of and + /// This class is a merger of and /// with join option removed /// public sealed class Arguments : TransformInputBase @@ -169,13 +170,13 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV throw h.ExceptUserArg(nameof(args.HashBits), "Number of bits must be between 1 and {0}", NumBitsLim - 1); // creating the Hash function - var hashArgs = new HashTransform.Arguments + var hashArgs = new HashTransformer.Arguments { HashBits = args.HashBits, Seed = args.Seed, Ordered = args.Ordered, InvertHash = args.InvertHash, - Column = new HashTransform.Column[args.Column.Length] + Column = new HashTransformer.Column[args.Column.Length] }; for (int i = 0; i < args.Column.Length; i++) { @@ -184,7 +185,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV throw h.ExceptUserArg(nameof(Column.Name)); h.Assert(!string.IsNullOrWhiteSpace(column.Name)); h.Assert(!string.IsNullOrWhiteSpace(column.Source)); - hashArgs.Column[i] = new HashTransform.Column + hashArgs.Column[i] = new HashTransformer.Column { HashBits = column.HashBits, Seed = 
column.Seed, @@ -198,7 +199,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV return CreateTransformCore( args.OutputKind, args.Column, args.Column.Select(col => col.OutputKind).ToList(), - new HashTransform(h, hashArgs, input), + HashTransformer.Create(h, hashArgs, input), h, args); } diff --git a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs index 4341afbb92..d8a411d48c 100644 --- a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs @@ -13,6 +13,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Transforms; [assembly: LoadableClass(WordBagTransform.Summary, typeof(IDataTransform), typeof(WordBagTransform), typeof(WordBagTransform.Arguments), typeof(SignatureDataTransform), "Word Bag Transform", "WordBagTransform", "WordBag")] @@ -474,7 +475,7 @@ public interface INgramExtractorFactory { /// /// Whether the extractor transform created by this factory uses the hashing trick - /// (by using or , for example). + /// (by using or , for example). 
/// bool UseHashingTrick { get; } diff --git a/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs index b51cccd220..1515801bcb 100644 --- a/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs @@ -11,6 +11,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Transforms; [assembly: LoadableClass(WordHashBagTransform.Summary, typeof(IDataTransform), typeof(WordHashBagTransform), typeof(WordHashBagTransform.Arguments), typeof(SignatureDataTransform), "Word Hash Bag Transform", "WordHashBagTransform", "WordHashBag")] @@ -266,7 +267,7 @@ public bool TryUnparse(StringBuilder sb) } /// - /// This class is a merger of and + /// This class is a merger of and /// , with the ordered option, /// the rehashUnigrams option and the allLength option removed. /// @@ -340,7 +341,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV List termCols = null; if (termLoaderArgs != null) termCols = new List(); - var hashColumns = new List(); + var hashColumns = new List(); var ngramHashColumns = new NgramHashTransform.Column[args.Column.Length]; var colCount = args.Column.Length; @@ -371,7 +372,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV } hashColumns.Add( - new HashTransform.Column + new HashTransformer.Column { Name = tmpName, Source = termLoaderArgs == null ? 
column.Source[isrc] : tmpName, @@ -435,7 +436,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV // Args for the Hash function with multiple columns var hashArgs = - new HashTransform.Arguments + new HashTransformer.Arguments { HashBits = 31, Seed = args.Seed, @@ -444,7 +445,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV InvertHash = args.InvertHash }; - view = new HashTransform(h, hashArgs, view); + view = HashTransformer.Create(h, hashArgs, view); // creating the NgramHash function var ngramHashArgs = diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs index 3abff3e560..e8bc11143e 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs @@ -2,18 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using Float = System.Single; - -using System; -using System.Collections.Generic; -using System.IO; -using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.TextAnalytics; +using Microsoft.ML.Transforms; +using System; using Xunit; +using Float = System.Single; namespace Microsoft.ML.Runtime.RunTests { @@ -82,14 +76,14 @@ private void TestHashTransformHelper(T[] data, uint[] results, NumberType typ builder.AddColumn("F1", type, data); var srcView = builder.GetDataView(); - HashTransform.Column col = new HashTransform.Column(); - col.Source = "F1"; + var col = new HashTransformer.Column(); + col.Name = "F1"; col.HashBits = 5; col.Seed = 42; - HashTransform.Arguments args = new HashTransform.Arguments(); - args.Column = new HashTransform.Column[] { col }; + var args = new HashTransformer.Arguments(); + args.Column = new HashTransformer.Column[] { col }; - var hashTransform = new HashTransform(Env, args, srcView); + var hashTransform = HashTransformer.Create(Env, args, srcView); using (var cursor = hashTransform.GetRowCursor(c => true)) { var resultGetter = cursor.GetGetter(1); @@ -120,14 +114,14 @@ private void TestHashTransformVectorHelper(VBuffer data, uint[][] results, private void TestHashTransformVectorHelper(ArrayDataViewBuilder builder, uint[][] results) { var srcView = builder.GetDataView(); - HashTransform.Column col = new HashTransform.Column(); - col.Source = "F1V"; + var col = new HashTransformer.Column(); + col.Name = "F1V"; col.HashBits = 5; col.Seed = 42; - HashTransform.Arguments args = new HashTransform.Arguments(); - args.Column = new HashTransform.Column[] { col }; + var args = new HashTransformer.Arguments(); + args.Column = new HashTransformer.Column[] { col }; - var hashTransform = new HashTransform(Env, args, srcView); + var hashTransform = HashTransformer.Create(Env, args, srcView); 
using (var cursor = hashTransform.GetRowCursor(c => true)) { var resultGetter = cursor.GetGetter>(1); diff --git a/test/Microsoft.ML.Tests/Transformers/HashTests.cs b/test/Microsoft.ML.Tests/Transformers/HashTests.cs new file mode 100644 index 0000000000..d52a800bb1 --- /dev/null +++ b/test/Microsoft.ML.Tests/Transformers/HashTests.cs @@ -0,0 +1,134 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.Runtime.Tools; +using Microsoft.ML.Transforms; +using System; +using System.IO; +using System.Linq; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests.Transformers +{ + public class HashTests : TestDataPipeBase + { + public HashTests(ITestOutputHelper output) : base(output) + { + } + + private class TestClass + { + public float A; + public float B; + public float C; + } + + private class TestMeta + { + [VectorType(2)] + public float[] A; + public float B; + [VectorType(2)] + public double[] C; + public double D; + } + + [Fact] + public void HashWorkout() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + + var dataView = ComponentCreation.CreateDataView(Env, data); + var pipe = new HashEstimator(Env, new[]{ + new HashTransformer.ColumnInfo("A", "HashA", hashBits:4, invertHash:-1), + new HashTransformer.ColumnInfo("B", "HashB", hashBits:3, ordered:true), + new HashTransformer.ColumnInfo("C", "HashC", seed:42), + new HashTransformer.ColumnInfo("A", "HashD"), + }); + + TestEstimatorCore(pipe, dataView); + Done(); + } + + [Fact] + public void TestMetadata() + { + + var data = new[] { + new TestMeta() { A=new float[2] { 3.5f, 2.5f}, B=1, C= new double[2] { 5.1f, 6.1f}, D= 7}, + 
new TestMeta() { A=new float[2] { 3.5f, 2.5f}, B=1, C= new double[2] { 5.1f, 6.1f}, D= 7}, + new TestMeta() { A=new float[2] { 3.5f, 2.5f}, B=1, C= new double[2] { 5.1f, 6.1f}, D= 7}}; + + + var dataView = ComponentCreation.CreateDataView(Env, data); + var pipe = new HashEstimator(Env, new[] { + new HashTransformer.ColumnInfo("A", "HashA", invertHash:1, hashBits:10), + new HashTransformer.ColumnInfo("A", "HashAUnlim", invertHash:-1, hashBits:10), + new HashTransformer.ColumnInfo("A", "HashAUnlimOrdered", invertHash:-1, hashBits:10, ordered:true) + }); + var result = pipe.Fit(dataView).Transform(dataView); + ValidateMetadata(result); + Done(); + } + + private void ValidateMetadata(IDataView result) + { + + Assert.True(result.Schema.TryGetColumnIndex("HashA", out int HashA)); + Assert.True(result.Schema.TryGetColumnIndex("HashAUnlim", out int HashAUnlim)); + Assert.True(result.Schema.TryGetColumnIndex("HashAUnlimOrdered", out int HashAUnlimOrdered)); + VBuffer> keys = default; + var types = result.Schema.GetMetadataTypes(HashA); + Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.KeyValues }); + result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, HashA, ref keys); + Assert.True(keys.Length == 1024); + //REVIEW: This is weird. I specified invertHash to 1 so I expect only one value to be in key values, but i got two. 
+ Assert.Equal(keys.Items().Select(x => x.Value.ToString()), new string[2] {"2.5", "3.5" }); + + types = result.Schema.GetMetadataTypes(HashAUnlim); + Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.KeyValues }); + result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, HashA, ref keys); + Assert.True(keys.Length == 1024); + Assert.Equal(keys.Items().Select(x => x.Value.ToString()), new string[2] { "2.5", "3.5" }); + + types = result.Schema.GetMetadataTypes(HashAUnlimOrdered); + Assert.Equal(types.Select(x => x.Key), new string[1] { MetadataUtils.Kinds.KeyValues }); + result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, HashA, ref keys); + Assert.True(keys.Length == 1024); + Assert.Equal(keys.Items().Select(x => x.Value.ToString()), new string[2] { "2.5", "3.5" }); + } + + [Fact] + public void TestCommandLine() + { + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Hash{col=B:A} in=f:\2.txt" }), (int)0); + } + + [Fact] + public void TestOldSavingAndLoading() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + var pipe = new HashEstimator(Env, new[]{ + new HashTransformer.ColumnInfo("A", "HashA", hashBits:4, invertHash:-1), + new HashTransformer.ColumnInfo("B", "HashB", hashBits:3, ordered:true), + new HashTransformer.ColumnInfo("C", "HashC", seed:42), + new HashTransformer.ColumnInfo("A", "HashD"), + }); + var result = pipe.Fit(dataView).Transform(dataView); + var resultRoles = new RoleMappedData(result); + using (var ms = new MemoryStream()) + { + TrainUtils.SaveModel(Env, Env.Start("saving"), ms, null, resultRoles); + ms.Position = 0; + var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms); + } + } + } +}