diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
index b0f2e12ef1..f6670394ee 100644
--- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
@@ -13,6 +13,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Calibration;
using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Transforms;
[assembly: LoadableClass(typeof(CrossValidationCommand), typeof(CrossValidationCommand.Arguments), typeof(SignatureCommand),
"Cross Validation", CrossValidationCommand.LoadName)]
@@ -329,10 +330,7 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output
int inc = 0;
while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc);
- var hashargs = new HashTransform.Arguments();
- hashargs.Column = new[] { new HashTransform.Column { Source = origStratCol, Name = stratificationColumn } };
- hashargs.HashBits = 30;
- output = new HashTransform(Host, hashargs, input);
+ output = new HashEstimator(Host, origStratCol, stratificationColumn, 30).Fit(input).Transform(input);
}
}
diff --git a/src/Microsoft.ML.Data/Transforms/HashTransform.cs b/src/Microsoft.ML.Data/Transforms/HashTransform.cs
index 8eec4e1581..a589121be6 100644
--- a/src/Microsoft.ML.Data/Transforms/HashTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/HashTransform.cs
@@ -2,24 +2,32 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Runtime.CompilerServices;
-using System.Text;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
+using Microsoft.ML.Transforms;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text;
-[assembly: LoadableClass(HashTransform.Summary, typeof(HashTransform), typeof(HashTransform.Arguments), typeof(SignatureDataTransform),
+[assembly: LoadableClass(HashTransformer.Summary, typeof(IDataTransform), typeof(HashTransformer), typeof(HashTransformer.Arguments), typeof(SignatureDataTransform),
"Hash Transform", "HashTransform", "Hash", DocName = "transform/HashTransform.md")]
-[assembly: LoadableClass(HashTransform.Summary, typeof(HashTransform), null, typeof(SignatureLoadDataTransform),
- "Hash Transform", HashTransform.LoaderSignature)]
+[assembly: LoadableClass(HashTransformer.Summary, typeof(IDataTransform), typeof(HashTransformer), null, typeof(SignatureLoadDataTransform),
+ "Hash Transform", HashTransformer.LoaderSignature)]
+
+[assembly: LoadableClass(HashTransformer.Summary, typeof(HashTransformer), null, typeof(SignatureLoadModel),
+ "Hash Transform", HashTransformer.LoaderSignature)]
-namespace Microsoft.ML.Runtime.Data
+[assembly: LoadableClass(typeof(IRowMapper), typeof(HashTransformer), null, typeof(SignatureLoadRowMapper),
+ "Hash Transform", HashTransformer.LoaderSignature)]
+
+namespace Microsoft.ML.Transforms
{
using Conditional = System.Diagnostics.ConditionalAttribute;
@@ -28,19 +36,8 @@ namespace Microsoft.ML.Runtime.Data
/// it hashes each slot separately.
/// It can hash either text values or key values.
///
- public sealed class HashTransform : OneToOneTransformBase, ITransformTemplate
+ public sealed class HashTransformer : OneToOneTransformerBase
{
- public const int NumBitsMin = 1;
- public const int NumBitsLim = 32;
-
- private static class Defaults
- {
- public const int HashBits = NumBitsLim - 1;
- public const uint Seed = 314489979;
- public const bool Ordered = false;
- public const int InvertHash = 0;
- }
-
public sealed class Arguments
{
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col",
@@ -49,18 +46,18 @@ public sealed class Arguments
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of bits to hash into. Must be between 1 and 31, inclusive",
ShortName = "bits", SortOrder = 2)]
- public int HashBits = Defaults.HashBits;
+ public int HashBits = HashEstimator.Defaults.HashBits;
[Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")]
- public uint Seed = Defaults.Seed;
+ public uint Seed = HashEstimator.Defaults.Seed;
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether the position of each term should be included in the hash",
ShortName = "ord")]
- public bool Ordered = Defaults.Ordered;
+ public bool Ordered = HashEstimator.Defaults.Ordered;
[Argument(ArgumentType.AtMostOnce, HelpText = "Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.",
ShortName = "ih")]
- public int InvertHash = Defaults.InvertHash;
+ public int InvertHash = HashEstimator.Defaults.InvertHash;
}
public sealed class Column : OneToOneColumn
@@ -95,8 +92,7 @@ protected override bool TryParse(string str)
// We accept N:B:S where N is the new column name, B is the number of bits,
// and S is source column names.
- string extra;
- if (!base.TryParse(str, out extra))
+ if (!base.TryParse(str, out string extra))
return false;
if (extra == null)
return true;
@@ -120,57 +116,72 @@ public bool TryUnparse(StringBuilder sb)
}
}
- private sealed class ColInfoEx
+ public sealed class ColumnInfo
{
+ public readonly string Input;
+ public readonly string Output;
public readonly int HashBits;
- public readonly uint HashSeed;
+ public readonly uint Seed;
public readonly bool Ordered;
+ public readonly int InvertHash;
- public ColInfoEx(Arguments args, Column col)
+ ///
+ /// Describes how the transformer handles one column pair.
+ ///
+ /// Name of input column.
+ /// Name of output column.
+ /// Number of bits to hash into. Must be between 1 and 31, inclusive.
+ /// Hashing seed.
+ /// Whether the position of each term should be included in the hash.
+ /// Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.
+ public ColumnInfo(string input, string output,
+ int hashBits = HashEstimator.Defaults.HashBits,
+ uint seed = HashEstimator.Defaults.Seed,
+ bool ordered = HashEstimator.Defaults.Ordered,
+ int invertHash = HashEstimator.Defaults.InvertHash)
{
- HashBits = col.HashBits ?? args.HashBits;
- if (HashBits < NumBitsMin || HashBits >= NumBitsLim)
- throw Contracts.ExceptUserArg(nameof(args.HashBits), "Should be between {0} and {1} inclusive", NumBitsMin, NumBitsLim - 1);
- HashSeed = col.Seed ?? args.Seed;
- Ordered = col.Ordered ?? args.Ordered;
+ // Reject bit counts outside [NumBitsMin, NumBitsLim) up front: the model loading path
+ // (CheckDecode) and Save (Assert) in this class both require this range, so an
+ // out-of-range value must not be allowed to reach a constructed instance.
+ if (hashBits < HashEstimator.NumBitsMin || hashBits >= HashEstimator.NumBitsLim)
+ throw Contracts.ExceptParam(nameof(hashBits), $"Should be between {HashEstimator.NumBitsMin} and {HashEstimator.NumBitsLim - 1} inclusive");
+ // -1 means "no limit", 0 means "no invert hashing"; anything below -1 is meaningless.
+ if (invertHash < -1)
+ throw Contracts.ExceptParam(nameof(invertHash), "Value too small, must be -1 or larger");
+ // Invert hashing is only supported for hashes of at most 30 bits.
+ // The message interpolates hashBits directly; the previous "{0}" inside an interpolated
+ // string always rendered the literal 0 instead of the offending value.
+ if (invertHash != 0 && hashBits >= 31)
+ throw Contracts.ExceptParam(nameof(hashBits), $"Cannot support invertHash for a {hashBits} bit hash. 30 is the maximum possible.");
+
+ Input = input;
+ Output = output;
+ HashBits = hashBits;
+ Seed = seed;
+ Ordered = ordered;
+ InvertHash = invertHash;
}
- public ColInfoEx(ModelLoadContext ctx)
+ internal ColumnInfo(string input, string output, ModelLoadContext ctx)
{
+ Input = input;
+ Output = output;
// *** Binary format ***
// int: HashBits
// uint: HashSeed
// byte: Ordered
-
HashBits = ctx.Reader.ReadInt32();
- Contracts.CheckDecode(NumBitsMin <= HashBits && HashBits < NumBitsLim);
-
- HashSeed = ctx.Reader.ReadUInt32();
+ Contracts.CheckDecode(HashEstimator.NumBitsMin <= HashBits && HashBits < HashEstimator.NumBitsLim);
+ Seed = ctx.Reader.ReadUInt32();
Ordered = ctx.Reader.ReadBoolByte();
}
- public void Save(ModelSaveContext ctx)
+ internal void Save(ModelSaveContext ctx)
{
// *** Binary format ***
// int: HashBits
// uint: HashSeed
// byte: Ordered
- Contracts.Assert(NumBitsMin <= HashBits && HashBits < NumBitsLim);
+ Contracts.Assert(HashEstimator.NumBitsMin <= HashBits && HashBits < HashEstimator.NumBitsLim);
ctx.Writer.Write(HashBits);
- ctx.Writer.Write(HashSeed);
+ ctx.Writer.Write(Seed);
ctx.Writer.WriteBoolByte(Ordered);
}
}
- private static string TestType(ColumnType type)
- {
- if (type.ItemType.IsText || type.ItemType.IsKey || type.ItemType == NumberType.R4 || type.ItemType == NumberType.R8)
- return null;
- return "Expected Text, Key, Single or Double item type";
- }
-
private const string RegistrationName = "Hash";
internal const string Summary = "Converts column values into hashes. This transform accepts text and keys as inputs. It works on single- and vector-valued columns, "
@@ -188,109 +199,81 @@ private static VersionInfo GetVersionInfo()
loaderSignature: LoaderSignature);
}
- private readonly ColInfoEx[] _exes;
- private readonly ColumnType[] _types;
-
+ private readonly ColumnInfo[] _columns;
private readonly VBuffer>[] _keyValues;
private readonly ColumnType[] _kvTypes;
- public static HashTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol)
{
- Contracts.CheckValue(env, nameof(env));
- var h = env.Register(RegistrationName);
- h.CheckValue(ctx, nameof(ctx));
- h.CheckValue(input, nameof(input));
- ctx.CheckAtModel(GetVersionInfo());
- return h.Apply("Loading Model", ch => new HashTransform(h, ctx, input));
+ var type = inputSchema.GetColumnType(srcCol);
+ if (!HashEstimator.IsColumnTypeValid(type))
+ throw Host.ExceptParam(nameof(inputSchema), HashEstimator.ExpectedColumnType);
}
- private HashTransform(IHost host, ModelLoadContext ctx, IDataView input)
- : base(host, ctx, input, TestType)
+ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns)
{
- Host.AssertValue(ctx);
-
- // *** Binary format ***
- //
- // Exes
-
- Host.AssertNonEmpty(Infos);
- _exes = new ColInfoEx[Infos.Length];
- for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
- _exes[iinfo] = new ColInfoEx(ctx);
-
- _types = InitColumnTypes();
-
- TextModelHelper.LoadAll(Host, ctx, Infos.Length, out _keyValues, out _kvTypes);
- SetMetadata();
+ Contracts.CheckNonEmpty(columns, nameof(columns));
+ return columns.Select(x => (x.Input, x.Output)).ToArray();
}
- public override void Save(ModelSaveContext ctx)
+ private ColumnType GetOutputType(ISchema inputSchema, ColumnInfo column)
{
- Host.CheckValue(ctx, nameof(ctx));
- ctx.CheckAtModel();
- ctx.SetVersionInfo(GetVersionInfo());
-
- //
- //
- // Exes
-
- SaveBase(ctx);
- Host.Assert(_exes.Length == Infos.Length);
- for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
- _exes[iinfo].Save(ctx);
-
- TextModelHelper.SaveAll(Host, ctx, Infos.Length, _keyValues);
+ var keyCount = column.HashBits < 31 ? 1 << column.HashBits : 0;
+ inputSchema.TryGetColumnIndex(column.Input, out int srcCol);
+ var itemType = new KeyType(DataKind.U4, 0, keyCount, keyCount > 0);
+ var srcType = inputSchema.GetColumnType(srcCol);
+ if (!srcType.IsVector)
+ return itemType;
+ else
+ return new VectorType(itemType, srcType.VectorSize);
}
///
- /// Convenience constructor for public facing API.
+ /// Constructor for case where you don't need to 'train' transform on data, e.g. InvertHash for all columns set to zero.
///
/// Host Environment.
- /// Input . This is the output from previous transform or loader.
- /// Name of the output column.
- /// Name of the column to be transformed. If this is null '' will be used.
- /// Number of bits to hash into. Must be between 1 and 31, inclusive.
- /// Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.
- public HashTransform(IHostEnvironment env,
- IDataView input,
- string name,
- string source = null,
- int hashBits = Defaults.HashBits,
- int invertHash = Defaults.InvertHash)
- : this(env, new Arguments() {
- Column = new[] { new Column() { Source = source ?? name, Name = name } },
- HashBits = hashBits, InvertHash = invertHash }, input)
+ /// Description of dataset columns and how to process them.
+ public HashTransformer(IHostEnvironment env, ColumnInfo[] columns) :
+ base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns))
{
+ // Keep a private copy so later changes to the caller's array do not affect this transformer.
+ _columns = columns.ToArray();
+ foreach (var column in _columns)
+ {
+ // This overload receives no data. The invert-hash key values are built by the
+ // IDataView-taking constructor, which cursors over the input, so a non-zero
+ // InvertHash cannot be honored here.
+ if (column.InvertHash != 0)
+ throw Host.ExceptParam(nameof(columns), $"Found column with {nameof(column.InvertHash)} set to non zero value, please use {nameof(HashEstimator)} instead");
+ }
}
- public HashTransform(IHostEnvironment env, Arguments args, IDataView input)
- : base(Contracts.CheckRef(env, nameof(env)), RegistrationName, env.CheckRef(args, nameof(args)).Column,
- input, TestType)
+ internal HashTransformer(IHostEnvironment env, IDataView input, ColumnInfo[] columns) :
+ base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns))
{
- if (args.HashBits < NumBitsMin || args.HashBits >= NumBitsLim)
- throw Host.ExceptUserArg(nameof(args.HashBits), "hashBits should be between {0} and {1} inclusive", NumBitsMin, NumBitsLim - 1);
-
- _exes = new ColInfoEx[Infos.Length];
+ _columns = columns.ToArray();
+ var types = new ColumnType[_columns.Length];
List invertIinfos = null;
List invertHashMaxCounts = null;
- for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
+ HashSet sourceColumnsForInvertHash = new HashSet();
+ for (int i = 0; i < _columns.Length; i++)
{
- _exes[iinfo] = new ColInfoEx(args, args.Column[iinfo]);
- int invertHashMaxCount = GetAndVerifyInvertHashMaxCount(args, args.Column[iinfo], _exes[iinfo]);
+ if (!input.Schema.TryGetColumnIndex(ColumnPairs[i].input, out int srcCol))
+ throw Host.ExceptSchemaMismatch(nameof(input), "input", ColumnPairs[i].input);
+ CheckInputColumn(input.Schema, i, srcCol);
+
+ types[i] = GetOutputType(input.Schema, _columns[i]);
+ int invertHashMaxCount;
+ if (_columns[i].InvertHash == -1)
+ invertHashMaxCount = int.MaxValue;
+ else
+ invertHashMaxCount = _columns[i].InvertHash;
if (invertHashMaxCount > 0)
{
- Utils.Add(ref invertIinfos, iinfo);
+ Utils.Add(ref invertIinfos, i);
Utils.Add(ref invertHashMaxCounts, invertHashMaxCount);
+ sourceColumnsForInvertHash.Add(srcCol);
}
}
-
- _types = InitColumnTypes();
-
- if (Utils.Size(invertIinfos) > 0)
+ if (Utils.Size(sourceColumnsForInvertHash) > 0)
{
- // Build the invert hashes for all columns for which it was requested.
- var srcs = new HashSet(invertIinfos.Select(i => Infos[i].Source));
- using (IRowCursor srcCursor = input.GetRowCursor(srcs.Contains))
+ using (IRowCursor srcCursor = input.GetRowCursor(sourceColumnsForInvertHash.Contains))
{
using (var ch = Host.Start("Invert hash building"))
{
@@ -299,154 +282,146 @@ public HashTransform(IHostEnvironment env, Arguments args, IDataView input)
for (int i = 0; i < helpers.Length; ++i)
{
int iinfo = invertIinfos[i];
- Host.Assert(_types[iinfo].ItemType.KeyCount > 0);
- var dstGetter = GetGetterCore(ch, srcCursor, iinfo, out disposer);
+ Host.Assert(types[iinfo].ItemType.KeyCount > 0);
+ var dstGetter = GetGetterCore(srcCursor, iinfo, out disposer);
Host.Assert(disposer == null);
- var ex = _exes[iinfo];
+ var ex = _columns[iinfo];
var maxCount = invertHashMaxCounts[i];
- helpers[i] = InvertHashHelper.Create(srcCursor, Infos[iinfo], ex, maxCount, dstGetter);
+ helpers[i] = InvertHashHelper.Create(srcCursor, ex, maxCount, dstGetter);
}
while (srcCursor.MoveNext())
{
for (int i = 0; i < helpers.Length; ++i)
helpers[i].Process();
}
- _keyValues = new VBuffer>[_exes.Length];
- _kvTypes = new ColumnType[_exes.Length];
+ _keyValues = new VBuffer>[_columns.Length];
+ _kvTypes = new ColumnType[_columns.Length];
for (int i = 0; i < helpers.Length; ++i)
{
_keyValues[invertIinfos[i]] = helpers[i].GetKeyValuesMetadata();
- Host.Assert(_keyValues[invertIinfos[i]].Length == _types[invertIinfos[i]].ItemType.KeyCount);
+ Host.Assert(_keyValues[invertIinfos[i]].Length == types[invertIinfos[i]].ItemType.KeyCount);
_kvTypes[invertIinfos[i]] = new VectorType(TextType.Instance, _keyValues[invertIinfos[i]].Length);
}
ch.Done();
}
}
}
- SetMetadata();
}
- ///
- /// Re-apply constructor.
- ///
- private HashTransform(IHostEnvironment env, HashTransform transform, IDataView newSource)
- : base(env, RegistrationName, transform, newSource, TestType)
+ private Delegate GetGetterCore(IRow input, int iinfo, out Action disposer)
{
- _exes = transform._exes;
- _types = InitColumnTypes();
- _keyValues = transform._keyValues;
- _kvTypes = transform._kvTypes;
- SetMetadata();
+ Host.AssertValue(input);
+ Host.Assert(0 <= iinfo && iinfo < _columns.Length);
+ disposer = null;
+ input.Schema.TryGetColumnIndex(_columns[iinfo].Input, out int srcCol);
+ var srcType = input.Schema.GetColumnType(srcCol);
+ if (!srcType.IsVector)
+ return ComposeGetterOne(input, iinfo, srcCol, srcType);
+ return ComposeGetterVec(input, iinfo, srcCol, srcType);
}
- public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource)
- {
- return new HashTransform(env, this, newSource);
- }
+ protected override IRowMapper MakeRowMapper(ISchema schema) => new Mapper(this, schema);
- private static int GetAndVerifyInvertHashMaxCount(Arguments args, Column col, ColInfoEx ex)
+ // Factory method for SignatureLoadModel.
+ private static HashTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
- var invertHashMaxCount = col.InvertHash ?? args.InvertHash;
- if (invertHashMaxCount != 0)
- {
- if (invertHashMaxCount == -1)
- invertHashMaxCount = int.MaxValue;
- Contracts.CheckUserArg(invertHashMaxCount > 0, nameof(args.InvertHash), "Value too small, must be -1 or larger");
- // If the bits is 31 or higher, we can't declare a KeyValues of the appropriate length,
- // this requiring a VBuffer of length 1u << 31 which exceeds int.MaxValue.
- if (ex.HashBits >= 31)
- throw Contracts.ExceptUserArg(nameof(args.InvertHash), "Cannot support invertHash for a {0} bit hash. 30 is the maximum possible.", ex.HashBits);
- }
- return invertHashMaxCount;
- }
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register(RegistrationName);
- private ColumnType[] InitColumnTypes()
- {
- var types = new ColumnType[Infos.Length];
- for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
- {
- var keyCount = _exes[iinfo].HashBits < 31 ? 1 << _exes[iinfo].HashBits : 0;
- var itemType = new KeyType(DataKind.U4, 0, keyCount, keyCount > 0);
- if (!Infos[iinfo].TypeSrc.IsVector)
- types[iinfo] = itemType;
- else
- types[iinfo] = new VectorType(itemType, Infos[iinfo].TypeSrc.VectorSize);
- }
- return types;
- }
+ host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel(GetVersionInfo());
- protected override ColumnType GetColumnTypeCore(int iinfo)
- {
- Host.Check(0 <= iinfo & iinfo < Infos.Length);
- return _types[iinfo];
+ return new HashTransformer(host, ctx);
}
- private void SetMetadata()
+ private HashTransformer(IHost host, ModelLoadContext ctx)
+ : base(host, ctx)
{
- var md = Metadata;
- for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
- {
- using (var bldr = md.BuildMetadata(iinfo, Source.Schema, Infos[iinfo].Source,
- MetadataUtils.Kinds.SlotNames))
- {
- if (_kvTypes != null && _kvTypes[iinfo] != null)
- bldr.AddGetter>>(MetadataUtils.Kinds.KeyValues, _kvTypes[iinfo], GetTerms);
- }
- }
- md.Seal();
+ var columnsLength = ColumnPairs.Length;
+ _columns = new ColumnInfo[columnsLength];
+ for (int i = 0; i < columnsLength; i++)
+ _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output, ctx);
+ TextModelHelper.LoadAll(Host, ctx, columnsLength, out _keyValues, out _kvTypes);
}
- private void GetTerms(int iinfo, ref VBuffer> dst)
+ public override void Save(ModelSaveContext ctx)
{
- Host.Assert(0 <= iinfo && iinfo < Infos.Length);
- Host.Assert(Utils.Size(_keyValues) == Infos.Length);
- Host.Assert(_keyValues[iinfo].Length > 0);
- _keyValues[iinfo].CopyTo(ref dst);
+ Host.CheckValue(ctx, nameof(ctx));
+
+ ctx.CheckAtModel();
+ ctx.SetVersionInfo(GetVersionInfo());
+
+ SaveColumns(ctx);
+
+ //
+ //
+ //
+ Host.Assert(_columns.Length == ColumnPairs.Length);
+ foreach (var col in _columns)
+ col.Save(ctx);
+
+ TextModelHelper.SaveAll(Host, ctx, _columns.Length, _keyValues);
}
- protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
+ // Factory method for SignatureLoadDataTransform.
+ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ => Create(env, ctx).MakeDataTransform(input);
+
+ // Factory method for SignatureLoadRowMapper.
+ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
+ => Create(env, ctx).MakeRowMapper(inputSchema);
+
+ // Factory method for SignatureDataTransform.
+ // Converts the command-line style Arguments into per-column ColumnInfo descriptions
+ // (per-column settings override the transform-wide ones) and builds the transform over 'input'.
+ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
- Host.AssertValueOrNull(ch);
- Host.AssertValue(input);
- Host.Assert(0 <= iinfo && iinfo < Infos.Length);
- disposer = null;
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(args, nameof(args));
+ env.CheckValue(input, nameof(input));
- if (!Infos[iinfo].TypeSrc.IsVector)
- return ComposeGetterOne(input, iinfo);
- return ComposeGetterVec(input, iinfo);
+ env.CheckValue(args.Column, nameof(args.Column));
+ var cols = new ColumnInfo[args.Column.Length];
+ for (int i = 0; i < cols.Length; i++)
+ {
+ var item = args.Column[i];
+ // A missing source name means the column is hashed under its own name.
+ cols[i] = new ColumnInfo(item.Source ?? item.Name,
+ item.Name,
+ item.HashBits ?? args.HashBits,
+ item.Seed ?? args.Seed,
+ item.Ordered ?? args.Ordered,
+ item.InvertHash ?? args.InvertHash);
+ }
+ return new HashTransformer(env, input, cols).MakeDataTransform(input);
}
- ///
- /// Getter generator for single valued inputs
- ///
- private ValueGetter ComposeGetterOne(IRow input, int iinfo)
+ #region Getters
+ private ValueGetter ComposeGetterOne(IRow input, int iinfo, int srcCol, ColumnType srcType)
{
- var colType = Infos[iinfo].TypeSrc;
- Host.Assert(colType.IsText || colType.IsKey || colType == NumberType.R4 || colType == NumberType.R8);
+ Host.Assert(srcType.IsText || srcType.IsKey || srcType == NumberType.R4 || srcType == NumberType.R8);
- var mask = (1U << _exes[iinfo].HashBits) - 1;
- uint seed = _exes[iinfo].HashSeed;
+ var mask = (1U << _columns[iinfo].HashBits) - 1;
+ uint seed = _columns[iinfo].Seed;
// In case of single valued input column, hash in 0 for the slot index.
- if (_exes[iinfo].Ordered)
+ if (_columns[iinfo].Ordered)
seed = Hashing.MurmurRound(seed, 0);
- switch (colType.RawKind)
+ switch (srcType.RawKind)
{
- case DataKind.Text:
- return ComposeGetterOneCore(GetSrcGetter>(input, iinfo), seed, mask);
- case DataKind.U1:
- return ComposeGetterOneCore(GetSrcGetter(input, iinfo), seed, mask);
- case DataKind.U2:
- return ComposeGetterOneCore(GetSrcGetter(input, iinfo), seed, mask);
- case DataKind.U4:
- return ComposeGetterOneCore(GetSrcGetter(input, iinfo), seed, mask);
- case DataKind.R4:
- return ComposeGetterOneCore(GetSrcGetter(input, iinfo), seed, mask);
- case DataKind.R8:
- return ComposeGetterOneCore(GetSrcGetter(input, iinfo), seed, mask);
- default:
- Host.Assert(colType.RawKind == DataKind.U8);
- return ComposeGetterOneCore(GetSrcGetter(input, iinfo), seed, mask);
+ case DataKind.Text:
+ return ComposeGetterOneCore(input.GetGetter>(srcCol), seed, mask);
+ case DataKind.U1:
+ return ComposeGetterOneCore(input.GetGetter(srcCol), seed, mask);
+ case DataKind.U2:
+ return ComposeGetterOneCore(input.GetGetter(srcCol), seed, mask);
+ case DataKind.U4:
+ return ComposeGetterOneCore(input.GetGetter(srcCol), seed, mask);
+ case DataKind.R4:
+ return ComposeGetterOneCore(input.GetGetter(srcCol), seed, mask);
+ case DataKind.R8:
+ return ComposeGetterOneCore(input.GetGetter(srcCol), seed, mask);
+ default:
+ Host.Assert(srcType.RawKind == DataKind.U8);
+ return ComposeGetterOneCore(input.GetGetter(srcCol), seed, mask);
}
}
@@ -537,43 +512,42 @@ private ValueGetter ComposeGetterOneCore(ValueGetter getSrc, uint
// them (either with their index or without) into dst. Additionally it fills in zero hashes in the rest of dst elements.
private delegate void HashLoopWithZeroHash(int count, int[] indices, TSrc[] src, uint[] dst, int dstCount, uint seed, uint mask);
- private ValueGetter> ComposeGetterVec(IRow input, int iinfo)
+ private ValueGetter> ComposeGetterVec(IRow input, int iinfo, int srcCol, ColumnType srcType)
{
- var colType = Infos[iinfo].TypeSrc;
- Host.Assert(colType.IsVector);
- Host.Assert(colType.ItemType.IsText || colType.ItemType.IsKey || colType.ItemType == NumberType.R4 || colType.ItemType == NumberType.R8);
+ Host.Assert(srcType.IsVector);
+ Host.Assert(srcType.ItemType.IsText || srcType.ItemType.IsKey || srcType.ItemType == NumberType.R4 || srcType.ItemType == NumberType.R8);
- switch (colType.ItemType.RawKind)
+ switch (srcType.ItemType.RawKind)
{
- case DataKind.Text:
- return ComposeGetterVecCore>(input, iinfo, HashUnord, HashDense, HashSparse);
- case DataKind.U1:
- return ComposeGetterVecCore(input, iinfo, HashUnord, HashDense, HashSparse);
- case DataKind.U2:
- return ComposeGetterVecCore(input, iinfo, HashUnord, HashDense, HashSparse);
- case DataKind.U4:
- return ComposeGetterVecCore(input, iinfo, HashUnord, HashDense, HashSparse);
- case DataKind.R4:
- return ComposeGetterVecCoreFloat(input, iinfo, HashSparseUnord, HashUnord, HashDense);
- case DataKind.R8:
- return ComposeGetterVecCoreFloat(input, iinfo, HashSparseUnord, HashUnord, HashDense);
- default:
- Host.Assert(colType.ItemType.RawKind == DataKind.U8);
- return ComposeGetterVecCore(input, iinfo, HashUnord, HashDense, HashSparse);
+ case DataKind.Text:
+ return ComposeGetterVecCore>(input, iinfo, srcCol, srcType, HashUnord, HashDense, HashSparse);
+ case DataKind.U1:
+ return ComposeGetterVecCore(input, iinfo, srcCol, srcType, HashUnord, HashDense, HashSparse);
+ case DataKind.U2:
+ return ComposeGetterVecCore(input, iinfo, srcCol, srcType, HashUnord, HashDense, HashSparse);
+ case DataKind.U4:
+ return ComposeGetterVecCore(input, iinfo, srcCol, srcType, HashUnord, HashDense, HashSparse);
+ case DataKind.R4:
+ return ComposeGetterVecCoreFloat(input, iinfo, srcCol, srcType, HashSparseUnord, HashUnord, HashDense);
+ case DataKind.R8:
+ return ComposeGetterVecCoreFloat(input, iinfo, srcCol, srcType, HashSparseUnord, HashUnord, HashDense);
+ default:
+ Host.Assert(srcType.ItemType.RawKind == DataKind.U8);
+ return ComposeGetterVecCore(input, iinfo, srcCol, srcType, HashUnord, HashDense, HashSparse);
}
}
- private ValueGetter> ComposeGetterVecCore(IRow input, int iinfo,
+ private ValueGetter> ComposeGetterVecCore(IRow input, int iinfo, int srcCol, ColumnType srcType,
HashLoop hasherUnord, HashLoop hasherDense, HashLoop hasherSparse)
{
- Host.Assert(Infos[iinfo].TypeSrc.IsVector);
- Host.Assert(Infos[iinfo].TypeSrc.ItemType.RawType == typeof(T));
+ Host.Assert(srcType.IsVector);
+ Host.Assert(srcType.ItemType.RawType == typeof(T));
- var getSrc = GetSrcGetter>(input, iinfo);
- var ex = _exes[iinfo];
+ var getSrc = input.GetGetter>(srcCol);
+ var ex = _columns[iinfo];
var mask = (1U << ex.HashBits) - 1;
- var seed = ex.HashSeed;
- var len = Infos[iinfo].TypeSrc.VectorSize;
+ var seed = ex.Seed;
+ var len = srcType.VectorSize;
var src = default(VBuffer);
if (!ex.Ordered)
@@ -612,20 +586,20 @@ private ValueGetter> ComposeGetterVecCore(IRow input, int iinfo
};
}
- private ValueGetter> ComposeGetterVecCoreFloat(IRow input, int iinfo,
+ private ValueGetter> ComposeGetterVecCoreFloat(IRow input, int iinfo, int srcCol, ColumnType srcType,
HashLoopWithZeroHash hasherSparseUnord, HashLoop hasherDenseUnord, HashLoop hasherDenseOrdered)
{
- Host.Assert(Infos[iinfo].TypeSrc.IsVector);
- Host.Assert(Infos[iinfo].TypeSrc.ItemType.RawType == typeof(T));
+ Host.Assert(srcType.IsVector);
+ Host.Assert(srcType.ItemType.RawType == typeof(T));
- var getSrc = GetSrcGetter>(input, iinfo);
- var ex = _exes[iinfo];
+ var getSrc = input.GetGetter>(srcCol);
+ var ex = _columns[iinfo];
var mask = (1U << ex.HashBits) - 1;
- var seed = ex.HashSeed;
- var len = Infos[iinfo].TypeSrc.VectorSize;
+ var seed = ex.Seed;
+ var len = srcType.VectorSize;
var src = default(VBuffer);
T[] denseValues = null;
- int expectedSrcLength = Infos[iinfo].TypeSrc.VectorSize;
+ int expectedSrcLength = srcType.VectorSize;
HashLoop hasherDense = ex.Ordered ? hasherDenseOrdered : hasherDenseUnord;
return
@@ -668,6 +642,8 @@ private ValueGetter> ComposeGetterVecCoreFloat(IRow input, int
};
}
+ #endregion
+
#region Core Hash functions, with and without index
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static uint HashCore(uint seed, ref ReadOnlyMemory value, uint mask)
@@ -704,7 +680,7 @@ private static uint HashCore(uint seed, ref float value, int i, uint mask)
if (float.IsNaN(value))
return 0;
return (Hashing.MixHash(Hashing.MurmurRound(Hashing.MurmurRound(seed, (uint)i),
- FloatUtils.GetBits(value == 0 ? 0: value))) & mask) + 1;
+ FloatUtils.GetBits(value == 0 ? 0 : value))) & mask) + 1;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -787,9 +763,9 @@ private static uint HashCore(uint seed, ulong value, int i, uint mask)
hash = Hashing.MurmurRound(hash, hi);
return (Hashing.MixHash(hash) & mask) + 1;
}
-#endregion Core Hash functions, with and without index
+ #endregion Core Hash functions, with and without index
-#region Unordered Loop: ignore indices
+ #region Unordered Loop: ignore indices
private static void HashUnord(int count, int[] indices, ReadOnlyMemory[] src, uint[] dst, uint seed, uint mask)
{
AssertValid(count, src, dst);
@@ -847,7 +823,7 @@ private static void HashUnord(int count, int[] indices, double[] src, uint[] dst
#endregion Unordered Loop: ignore indices
-#region Dense Loop: ignore indices
+ #region Dense Loop: ignore indices
private static void HashDense(int count, int[] indices, ReadOnlyMemory[] src, uint[] dst, uint seed, uint mask)
{
AssertValid(count, src, dst);
@@ -902,9 +878,9 @@ private static void HashDense(int count, int[] indices, double[] src, uint[] dst
for (int i = 0; i < count; i++)
dst[i] = HashCore(seed, ref src[i], i, mask);
}
-#endregion Dense Loop: ignore indices
+ #endregion Dense Loop: ignore indices
-#region Sparse Loop: use indices
+ #region Sparse Loop: use indices
private static void HashSparse(int count, int[] indices, ReadOnlyMemory[] src, uint[] dst, uint seed, uint mask)
{
AssertValid(count, src, dst);
@@ -992,7 +968,7 @@ private static void HashSparseUnord(int count, int[] indices, double[] src, uint
}
}
-#endregion Sparse Loop: use indices
+ #endregion Sparse Loop: use indices
[Conditional("DEBUG")]
private static void AssertValid(int count, T[] src, uint[] dst)
@@ -1002,27 +978,94 @@ private static void AssertValid(int count, T[] src, uint[] dst)
Contracts.Assert(count <= Utils.Size(dst));
}
- ///
- /// This is a utility class to acquire and build the inverse hash to populate
- /// KeyValues metadata.
- ///
+ private sealed class Mapper : MapperBase
+ {
+ private sealed class ColInfo
+ {
+ public readonly string Name;
+ public readonly string Source;
+ public readonly ColumnType TypeSrc;
+
+ public ColInfo(string name, string source, ColumnType type)
+ {
+ Name = name;
+ Source = source;
+ TypeSrc = type;
+ }
+ }
+
+ private readonly ColumnType[] _types;
+ private readonly HashTransformer _parent;
+
+ public Mapper(HashTransformer parent, ISchema inputSchema)
+ : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
+ {
+ _parent = parent;
+ _types = new ColumnType[_parent._columns.Length];
+ for (int i = 0; i < _types.Length; i++)
+ _types[i] = _parent.GetOutputType(inputSchema, _parent._columns[i]);
+ }
+
+ public override RowMapperColumnInfo[] GetOutputColumns()
+ {
+ var result = new RowMapperColumnInfo[_parent.ColumnPairs.Length];
+ for (int i = 0; i < _parent.ColumnPairs.Length; i++)
+ {
+ InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colIndex);
+ var colMetaInfo = new ColumnMetadataInfo(_parent.ColumnPairs[i].output);
+
+ foreach (var type in InputSchema.GetMetadataTypes(colIndex).Where(x => x.Key == MetadataUtils.Kinds.SlotNames))
+ Utils.MarshalInvoke(AddMetaGetter, type.Value.RawType, colMetaInfo, InputSchema, type.Key, type.Value, colIndex);
+ if (_parent._kvTypes != null && _parent._kvTypes[i] != null)
+ AddMetaKeyValues(i, colMetaInfo);
+ result[i] = new RowMapperColumnInfo(_parent.ColumnPairs[i].output, _types[i], colMetaInfo);
+ }
+ return result;
+ }
+ private void AddMetaKeyValues(int i, ColumnMetadataInfo colMetaInfo)
+ {
+ MetadataUtils.MetadataGetter>> getter = (int col, ref VBuffer> dst) =>
+ {
+ _parent._keyValues[i].CopyTo(ref dst);
+ };
+ var info = new MetadataInfo>>(_parent._kvTypes[i], getter);
+ colMetaInfo.Add(MetadataUtils.Kinds.KeyValues, info);
+ }
+
+ private int AddMetaGetter(ColumnMetadataInfo colMetaInfo, ISchema schema, string kind, ColumnType ct, int originalCol)
+ {
+ MetadataUtils.MetadataGetter getter = (int col, ref T dst) =>
+ {
+ // We don't care about 'col': this getter is specialized for a column 'originalCol',
+ // and 'col' in this case is the 'metadata kind index', not the column index.
+ schema.GetMetadata(kind, originalCol, ref dst);
+ };
+ var info = new MetadataInfo(ct, getter);
+ colMetaInfo.Add(kind, info);
+ return 0;
+ }
+
+ protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) => _parent.GetGetterCore(input, iinfo, out disposer);
+ }
+
private abstract class InvertHashHelper
{
protected readonly IRow Row;
private readonly bool _includeSlot;
- private readonly ColInfo _info;
- private readonly ColInfoEx _ex;
+ private readonly ColumnInfo _ex;
+ private readonly ColumnType _srcType;
+ private readonly int _srcCol;
- private InvertHashHelper(IRow row, ColInfo info, ColInfoEx ex)
+ private InvertHashHelper(IRow row, ColumnInfo ex)
{
Contracts.AssertValue(row);
- Contracts.AssertValue(info);
-
Row = row;
- _info = info;
+ row.Schema.TryGetColumnIndex(ex.Input, out int srcCol);
+ _srcCol = srcCol;
+ _srcType = row.Schema.GetColumnType(srcCol);
_ex = ex;
// If this is a vector and ordered, then we must include the slot as part of the representation.
- _includeSlot = _info.TypeSrc.IsVector && _ex.Ordered;
+ _includeSlot = _srcType.IsVector && _ex.Ordered;
}
///
@@ -1031,18 +1074,18 @@ private InvertHashHelper(IRow row, ColInfo info, ColInfoEx ex)
/// the row.
///
/// The input source row, from which the hashed values can be fetched
- /// The column info, describing the source
/// The extra column info
- /// The number of input hashed values to accumulate per output hash value
+ /// The number of input hashed values to accumulate per output hash value
/// A hash getter, built on top of .
- public static InvertHashHelper Create(IRow row, ColInfo info, ColInfoEx ex, int invertHashMaxCount, Delegate dstGetter)
+ public static InvertHashHelper Create(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter)
{
- ColumnType typeSrc = info.TypeSrc;
+ row.Schema.TryGetColumnIndex(ex.Input, out int srcCol);
+ ColumnType typeSrc = row.Schema.GetColumnType(srcCol);
Type t = typeSrc.IsVector ? (ex.Ordered ? typeof(ImplVecOrdered<>) : typeof(ImplVec<>)) : typeof(ImplOne<>);
t = t.MakeGenericType(typeSrc.ItemType.RawType);
- var consTypes = new Type[] { typeof(IRow), typeof(OneToOneTransformBase.ColInfo), typeof(ColInfoEx), typeof(int), typeof(Delegate) };
+ var consTypes = new Type[] { typeof(IRow), typeof(ColumnInfo), typeof(int), typeof(Delegate) };
var constructorInfo = t.GetConstructor(consTypes);
- return (InvertHashHelper)constructorInfo.Invoke(new object[] { row, info, ex, invertHashMaxCount, dstGetter });
+ return (InvertHashHelper)constructorInfo.Invoke(new object[] { row, ex, invertHashMaxCount, dstGetter });
}
///
@@ -1095,7 +1138,7 @@ public int GetHashCode(KeyValuePair obj)
private IEqualityComparer GetSimpleComparer()
{
- Contracts.Assert(_info.TypeSrc.ItemType.RawType == typeof(T));
+ Contracts.Assert(_srcType.ItemType.RawType == typeof(T));
if (typeof(T) == typeof(ReadOnlyMemory))
{
// We are hashing twice, once to assign to the slot, and then again,
@@ -1103,7 +1146,7 @@ private IEqualityComparer GetSimpleComparer()
// same seed used to assign to a slot, or otherwise this per-slot hash
// would have a lot of collisions. We ensure that we have different
// hash function by inverting the seed's bits.
- var c = new TextEqualityComparer(~_ex.HashSeed);
+ var c = new TextEqualityComparer(~_ex.Seed);
return c as IEqualityComparer;
}
// I assume I hope correctly that the default .NET hash function for uint
@@ -1117,11 +1160,10 @@ private abstract class Impl : InvertHashHelper
{
protected readonly InvertHashCollector Collector;
- protected Impl(IRow row, ColInfo info, ColInfoEx ex, int invertHashMaxCount)
- : base(row, info, ex)
+ protected Impl(IRow row, ColumnInfo ex, int invertHashMaxCount)
+ : base(row, ex)
{
Contracts.AssertValue(row);
- Contracts.AssertValue(info);
Contracts.AssertValue(ex);
Collector = new InvertHashCollector(1 << ex.HashBits, invertHashMaxCount, GetTextMap(), GetComparer());
@@ -1129,7 +1171,7 @@ protected Impl(IRow row, ColInfo info, ColInfoEx ex, int invertHashMaxCount)
protected virtual ValueMapper GetTextMap()
{
- return InvertHashUtils.GetSimpleMapper(Row.Schema, _info.Source);
+ return InvertHashUtils.GetSimpleMapper(Row.Schema, _srcCol);
}
protected virtual IEqualityComparer GetComparer()
@@ -1151,10 +1193,10 @@ private sealed class ImplOne : Impl
private T _value;
private uint _hash;
- public ImplOne(IRow row, OneToOneTransformBase.ColInfo info, ColInfoEx ex, int invertHashMaxCount, Delegate dstGetter)
- : base(row, info, ex, invertHashMaxCount)
+ public ImplOne(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter)
+ : base(row, ex, invertHashMaxCount)
{
- _srcGetter = Row.GetGetter(_info.Source);
+ _srcGetter = Row.GetGetter(_srcCol);
_dstGetter = dstGetter as ValueGetter;
Contracts.AssertValue(_dstGetter);
}
@@ -1185,10 +1227,10 @@ private sealed class ImplVec : Impl
private VBuffer _value;
private VBuffer _hash;
- public ImplVec(IRow row, OneToOneTransformBase.ColInfo info, ColInfoEx ex, int invertHashMaxCount, Delegate dstGetter)
- : base(row, info, ex, invertHashMaxCount)
+ public ImplVec(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter)
+ : base(row, ex, invertHashMaxCount)
{
- _srcGetter = Row.GetGetter>(_info.Source);
+ _srcGetter = Row.GetGetter>(_srcCol);
_dstGetter = dstGetter as ValueGetter>;
Contracts.AssertValue(_dstGetter);
}
@@ -1214,17 +1256,17 @@ private sealed class ImplVecOrdered : Impl>
private VBuffer _value;
private VBuffer _hash;
- public ImplVecOrdered(IRow row, OneToOneTransformBase.ColInfo info, ColInfoEx ex, int invertHashMaxCount, Delegate dstGetter)
- : base(row, info, ex, invertHashMaxCount)
+ public ImplVecOrdered(IRow row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter)
+ : base(row, ex, invertHashMaxCount)
{
- _srcGetter = Row.GetGetter>(_info.Source);
+ _srcGetter = Row.GetGetter>(_srcCol);
_dstGetter = dstGetter as ValueGetter>;
Contracts.AssertValue(_dstGetter);
}
protected override ValueMapper, StringBuilder> GetTextMap()
{
- var simple = InvertHashUtils.GetSimpleMapper(Row.Schema, _info.Source);
+ var simple = InvertHashUtils.GetSimpleMapper(Row.Schema, _srcCol);
return InvertHashUtils.GetPairMapper(simple);
}
@@ -1255,5 +1297,74 @@ public override void Process()
}
}
}
+
+ /// <summary>
+ /// Estimator for <see cref="HashTransformer"/>: each configured input column is hashed into a key-typed output column.
+ /// </summary>
+ public sealed class HashEstimator : IEstimator<HashTransformer>
+ {
+ public const int NumBitsMin = 1;
+ public const int NumBitsLim = 32;
+
+ /// <summary>
+ /// Default values used when a column description does not specify its own.
+ /// </summary>
+ public static class Defaults
+ {
+ public const int HashBits = NumBitsLim - 1;
+ public const uint Seed = 314489979;
+ public const bool Ordered = false;
+ public const int InvertHash = 0;
+ }
+
+ private readonly IHost _host;
+ private readonly HashTransformer.ColumnInfo[] _columns;
+
+ /// <summary>
+ /// Returns true when the item type of <paramref name="type"/> is text, key, Single or Double.
+ /// </summary>
+ public static bool IsColumnTypeValid(ColumnType type) => (type.ItemType.IsText || type.ItemType.IsKey || type.ItemType == NumberType.R4 || type.ItemType == NumberType.R8);
+
+ internal const string ExpectedColumnType = "Expected Text, Key, Single or Double item type";
+
+ /// <summary>
+ /// Convenience constructor for the simple one-column case.
+ /// </summary>
+ /// <param name="env">Host Environment.</param>
+ /// <param name="inputColumn">Name of the column to be transformed.</param>
+ /// <param name="outputColumn">Name of the output column. If this is null '<paramref name="inputColumn"/>' will be used.</param>
+ /// <param name="hashBits">Number of bits to hash into. Must be between 1 and 31, inclusive.</param>
+ /// <param name="invertHash">Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.</param>
+ public HashEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null,
+ int hashBits = Defaults.HashBits, int invertHash = Defaults.InvertHash)
+ : this(env, new HashTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, hashBits: hashBits, invertHash: invertHash))
+ {
+ }
+
+ /// <summary>
+ /// Creates an estimator that hashes every column described in <paramref name="columns"/>.
+ /// </summary>
+ /// <param name="env">Host Environment.</param>
+ /// <param name="columns">Description of dataset columns and how to process them.</param>
+ public HashEstimator(IHostEnvironment env, params HashTransformer.ColumnInfo[] columns)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ _host = env.Register(nameof(HashEstimator));
+ // Fail with the parameter's own name instead of the LINQ 'source' argument when null is passed.
+ _host.CheckValue(columns, nameof(columns));
+ // Keep a private copy so later changes to the caller's array do not affect this estimator.
+ _columns = columns.ToArray();
+ }
+
+ /// <summary>
+ /// Creates a <see cref="HashTransformer"/> for <paramref name="input"/> using the configured columns.
+ /// </summary>
+ public HashTransformer Fit(IDataView input) => new HashTransformer(_host, input, _columns);
+
+ /// <summary>
+ /// Computes the schema shape produced by the fitted transformer. Input columns are kept; each configured
+ /// output column is added, or replaces a same-named column, as a U4 key column with the vector kind of its input.
+ /// Slot names metadata of the input is forwarded, and key values metadata is advertised when invert hashing is requested.
+ /// </summary>
+ /// <param name="inputSchema">Shape of the input schema.</param>
+ /// <returns>Shape of the output schema.</returns>
+ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ _host.CheckValue(inputSchema, nameof(inputSchema));
+ var result = inputSchema.Columns.ToDictionary(x => x.Name);
+ foreach (var colInfo in _columns)
+ {
+ if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
+ throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
+ if (!IsColumnTypeValid(col.ItemType))
+ throw _host.ExceptParam(nameof(inputSchema), ExpectedColumnType);
+ var metadata = new List<SchemaShape.Column>();
+ if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta))
+ metadata.Add(slotMeta);
+ if (colInfo.InvertHash != 0)
+ metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false));
+ // The vector kind lives on the column, not on its item type: 'col.ItemType.IsVector' is never true,
+ // which made every vector input look like a scalar output. Hashing keeps the shape of the input.
+ result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, NumberType.U4, true, new SchemaShape(metadata));
+ }
+ return new SchemaShape(result.Values);
+ }
+ }
}
diff --git a/src/Microsoft.ML.Data/Transforms/TermEstimator.cs b/src/Microsoft.ML.Data/Transforms/TermEstimator.cs
index 446140573c..9f8135252e 100644
--- a/src/Microsoft.ML.Data/Transforms/TermEstimator.cs
+++ b/src/Microsoft.ML.Data/Transforms/TermEstimator.cs
@@ -25,13 +25,13 @@ public static class Defaults
/// Convenience constructor for public facing API.
///
/// Host Environment.
- /// Name of the output column.
- /// Name of the column to be transformed. If this is null '' will be used.
+ /// Name of the output column.
+ /// Name of the column to be transformed. If this is null '' will be used.
/// Maximum number of terms to keep per column when auto-training.
/// How items should be ordered when vectorized. By default, they will be in the order encountered.
/// If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a').
- public TermEstimator(IHostEnvironment env, string name, string source = null, int maxNumTerms = Defaults.MaxNumTerms, TermTransform.SortOrder sort = Defaults.Sort) :
- this(env, new TermTransform.ColumnInfo(name, source ?? name, maxNumTerms, sort))
+ public TermEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, int maxNumTerms = Defaults.MaxNumTerms, TermTransform.SortOrder sort = Defaults.Sort) :
+ this(env, new TermTransform.ColumnInfo(inputColumn, outputColumn ?? inputColumn, maxNumTerms, sort))
{
}
diff --git a/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs b/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs
index 582586b9a0..5808a5e3e0 100644
--- a/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs
+++ b/src/Microsoft.ML.Transforms/CategoricalHashTransform.cs
@@ -10,6 +10,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Transforms;
[assembly: LoadableClass(CategoricalHashTransform.Summary, typeof(IDataTransform), typeof(CategoricalHashTransform), typeof(CategoricalHashTransform.Arguments), typeof(SignatureDataTransform),
CategoricalHashTransform.UserName, "CategoricalHashTransform", "CatHashTransform", "CategoricalHash", "CatHash")]
@@ -91,7 +92,7 @@ private static class Defaults
}
///
- /// This class is a merger of and
+ /// This class is a merger of and
/// with join option removed
///
public sealed class Arguments : TransformInputBase
@@ -169,13 +170,13 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
throw h.ExceptUserArg(nameof(args.HashBits), "Number of bits must be between 1 and {0}", NumBitsLim - 1);
// creating the Hash function
- var hashArgs = new HashTransform.Arguments
+ var hashArgs = new HashTransformer.Arguments
{
HashBits = args.HashBits,
Seed = args.Seed,
Ordered = args.Ordered,
InvertHash = args.InvertHash,
- Column = new HashTransform.Column[args.Column.Length]
+ Column = new HashTransformer.Column[args.Column.Length]
};
for (int i = 0; i < args.Column.Length; i++)
{
@@ -184,7 +185,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
throw h.ExceptUserArg(nameof(Column.Name));
h.Assert(!string.IsNullOrWhiteSpace(column.Name));
h.Assert(!string.IsNullOrWhiteSpace(column.Source));
- hashArgs.Column[i] = new HashTransform.Column
+ hashArgs.Column[i] = new HashTransformer.Column
{
HashBits = column.HashBits,
Seed = column.Seed,
@@ -198,7 +199,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
return CreateTransformCore(
args.OutputKind, args.Column,
args.Column.Select(col => col.OutputKind).ToList(),
- new HashTransform(h, hashArgs, input),
+ HashTransformer.Create(h, hashArgs, input),
h,
args);
}
diff --git a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
index 4341afbb92..d8a411d48c 100644
--- a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
@@ -13,6 +13,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Transforms;
[assembly: LoadableClass(WordBagTransform.Summary, typeof(IDataTransform), typeof(WordBagTransform), typeof(WordBagTransform.Arguments), typeof(SignatureDataTransform),
"Word Bag Transform", "WordBagTransform", "WordBag")]
@@ -474,7 +475,7 @@ public interface INgramExtractorFactory
{
///
/// Whether the extractor transform created by this factory uses the hashing trick
- /// (by using or , for example).
+ /// (by using or , for example).
///
bool UseHashingTrick { get; }
diff --git a/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs
index b51cccd220..1515801bcb 100644
--- a/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs
@@ -11,6 +11,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Transforms;
[assembly: LoadableClass(WordHashBagTransform.Summary, typeof(IDataTransform), typeof(WordHashBagTransform), typeof(WordHashBagTransform.Arguments), typeof(SignatureDataTransform),
"Word Hash Bag Transform", "WordHashBagTransform", "WordHashBag")]
@@ -266,7 +267,7 @@ public bool TryUnparse(StringBuilder sb)
}
///
- /// This class is a merger of and
+ /// This class is a merger of and
/// , with the ordered option,
/// the rehashUnigrams option and the allLength option removed.
///
@@ -340,7 +341,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
List termCols = null;
if (termLoaderArgs != null)
termCols = new List();
- var hashColumns = new List();
+ var hashColumns = new List();
var ngramHashColumns = new NgramHashTransform.Column[args.Column.Length];
var colCount = args.Column.Length;
@@ -371,7 +372,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
}
hashColumns.Add(
- new HashTransform.Column
+ new HashTransformer.Column
{
Name = tmpName,
Source = termLoaderArgs == null ? column.Source[isrc] : tmpName,
@@ -435,7 +436,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
// Args for the Hash function with multiple columns
var hashArgs =
- new HashTransform.Arguments
+ new HashTransformer.Arguments
{
HashBits = 31,
Seed = args.Seed,
@@ -444,7 +445,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
InvertHash = args.InvertHash
};
- view = new HashTransform(h, hashArgs, view);
+ view = HashTransformer.Create(h, hashArgs, view);
// creating the NgramHash function
var ngramHashArgs =
diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs
index 3abff3e560..e8bc11143e 100644
--- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs
+++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs
@@ -2,18 +2,12 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using Float = System.Single;
-
-using System;
-using System.Collections.Generic;
-using System.IO;
-using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Runtime.Data.IO;
-using Microsoft.ML.Runtime.Internal.Utilities;
-using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.TextAnalytics;
+using Microsoft.ML.Transforms;
+using System;
using Xunit;
+using Float = System.Single;
namespace Microsoft.ML.Runtime.RunTests
{
@@ -82,14 +76,14 @@ private void TestHashTransformHelper(T[] data, uint[] results, NumberType typ
builder.AddColumn("F1", type, data);
var srcView = builder.GetDataView();
- HashTransform.Column col = new HashTransform.Column();
- col.Source = "F1";
+ var col = new HashTransformer.Column();
+ col.Name = "F1";
col.HashBits = 5;
col.Seed = 42;
- HashTransform.Arguments args = new HashTransform.Arguments();
- args.Column = new HashTransform.Column[] { col };
+ var args = new HashTransformer.Arguments();
+ args.Column = new HashTransformer.Column[] { col };
- var hashTransform = new HashTransform(Env, args, srcView);
+ var hashTransform = HashTransformer.Create(Env, args, srcView);
using (var cursor = hashTransform.GetRowCursor(c => true))
{
var resultGetter = cursor.GetGetter(1);
@@ -120,14 +114,14 @@ private void TestHashTransformVectorHelper(VBuffer data, uint[][] results,
private void TestHashTransformVectorHelper(ArrayDataViewBuilder builder, uint[][] results)
{
var srcView = builder.GetDataView();
- HashTransform.Column col = new HashTransform.Column();
- col.Source = "F1V";
+ var col = new HashTransformer.Column();
+ col.Name = "F1V";
col.HashBits = 5;
col.Seed = 42;
- HashTransform.Arguments args = new HashTransform.Arguments();
- args.Column = new HashTransform.Column[] { col };
+ var args = new HashTransformer.Arguments();
+ args.Column = new HashTransformer.Column[] { col };
- var hashTransform = new HashTransform(Env, args, srcView);
+ var hashTransform = HashTransformer.Create(Env, args, srcView);
using (var cursor = hashTransform.GetRowCursor(c => true))
{
var resultGetter = cursor.GetGetter>(1);
diff --git a/test/Microsoft.ML.Tests/Transformers/HashTests.cs b/test/Microsoft.ML.Tests/Transformers/HashTests.cs
new file mode 100644
index 0000000000..d52a800bb1
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Transformers/HashTests.cs
@@ -0,0 +1,134 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Model;
+using Microsoft.ML.Runtime.RunTests;
+using Microsoft.ML.Runtime.Tools;
+using Microsoft.ML.Transforms;
+using System;
+using System.IO;
+using System.Linq;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Microsoft.ML.Tests.Transformers
+{
+ /// <summary>
+ /// Tests for <see cref="HashEstimator"/> and <see cref="HashTransformer"/>: estimator workout, key-value metadata,
+ /// command-line construction, and saving/loading through TrainUtils.SaveModel and ModelFileUtils.LoadTransforms.
+ /// </summary>
+ public class HashTests : TestDataPipeBase
+ {
+ public HashTests(ITestOutputHelper output) : base(output)
+ {
+ }
+
+ // Row type with three scalar float columns, used by the workout and save/load tests.
+ private class TestClass
+ {
+ public float A;
+ public float B;
+ public float C;
+ }
+
+ // Row type mixing two-element vectors and scalars of float and double, used by the metadata test.
+ private class TestMeta
+ {
+ [VectorType(2)]
+ public float[] A;
+ public float B;
+ [VectorType(2)]
+ public double[] C;
+ public double D;
+ }
+
+ // Runs the standard estimator checks over columns that exercise hashBits, invertHash, ordered and seed.
+ [Fact]
+ public void HashWorkout()
+ {
+ var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
+
+ var dataView = ComponentCreation.CreateDataView(Env, data);
+ var pipe = new HashEstimator(Env, new[]{
+ new HashTransformer.ColumnInfo("A", "HashA", hashBits:4, invertHash:-1),
+ new HashTransformer.ColumnInfo("B", "HashB", hashBits:3, ordered:true),
+ new HashTransformer.ColumnInfo("C", "HashC", seed:42),
+ new HashTransformer.ColumnInfo("A", "HashD"),
+ });
+
+ TestEstimatorCore(pipe, dataView);
+ Done();
+ }
+
+ // Hashes the vector column A three ways with invert hashing enabled and validates the resulting key-value metadata.
+ [Fact]
+ public void TestMetadata()
+ {
+ var data = new[] {
+ new TestMeta() { A=new float[2] { 3.5f, 2.5f}, B=1, C= new double[2] { 5.1f, 6.1f}, D= 7},
+ new TestMeta() { A=new float[2] { 3.5f, 2.5f}, B=1, C= new double[2] { 5.1f, 6.1f}, D= 7},
+ new TestMeta() { A=new float[2] { 3.5f, 2.5f}, B=1, C= new double[2] { 5.1f, 6.1f}, D= 7}};
+
+ var dataView = ComponentCreation.CreateDataView(Env, data);
+ var pipe = new HashEstimator(Env, new[] {
+ new HashTransformer.ColumnInfo("A", "HashA", invertHash:1, hashBits:10),
+ new HashTransformer.ColumnInfo("A", "HashAUnlim", invertHash:-1, hashBits:10),
+ new HashTransformer.ColumnInfo("A", "HashAUnlimOrdered", invertHash:-1, hashBits:10, ordered:true)
+ });
+ var result = pipe.Fit(dataView).Transform(dataView);
+ ValidateMetadata(result);
+ Done();
+ }
+
+ // Checks the key-value metadata of each hashed column. Every column is read through its own index;
+ // previously all three reads used the index of "HashA", so the other two columns were never inspected.
+ private void ValidateMetadata(IDataView result)
+ {
+ //REVIEW: This is weird. I specified invertHash to 1 so I expect only one value to be in key values, but i got two.
+ AssertKeyValues(result, "HashA", new[] { "2.5", "3.5" }, sort: false);
+ // Same seed and bit count as HashA, so the values are expected in the same slots.
+ AssertKeyValues(result, "HashAUnlim", new[] { "2.5", "3.5" }, sort: false);
+ // NOTE(review): ordered hashing is expected to render each value as "slot:value"; the exact text and
+ // the absence of a collision between the two slots are assumptions - confirm on first run.
+ AssertKeyValues(result, "HashAUnlimOrdered", new[] { "0:3.5", "1:2.5" }, sort: true);
+ }
+
+ // Asserts that 'column' exposes only key-value metadata, sized for 10 hash bits (1024 slots),
+ // whose explicit entries equal 'expected'. With 'sort' the entries are compared independent of slot order.
+ private static void AssertKeyValues(IDataView view, string column, string[] expected, bool sort)
+ {
+ Assert.True(view.Schema.TryGetColumnIndex(column, out int col));
+ var kinds = view.Schema.GetMetadataTypes(col).Select(x => x.Key);
+ Assert.Equal(new[] { MetadataUtils.Kinds.KeyValues }, kinds);
+
+ VBuffer<ReadOnlyMemory<char>> keys = default;
+ view.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, col, ref keys);
+ Assert.Equal(1024, keys.Length);
+
+ var actual = keys.Items().Select(x => x.Value.ToString());
+ if (sort)
+ actual = actual.OrderBy(x => x, StringComparer.Ordinal);
+ Assert.Equal(expected, actual);
+ }
+
+ // Checks that the Hash transform can be instantiated from a command line.
+ // NOTE(review): the command references a hard-coded local path (f:\2.txt); confirm the test does not
+ // depend on that file existing on the machine running it.
+ [Fact]
+ public void TestCommandLine()
+ {
+ Assert.Equal(0, Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Hash{col=B:A} in=f:\2.txt" }));
+ }
+
+ // Saves the transformed data's model to a memory stream and loads the transforms back over the original data.
+ // NOTE(review): the loaded view is not inspected, so this only verifies that saving and loading do not throw.
+ [Fact]
+ public void TestOldSavingAndLoading()
+ {
+ var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
+ var dataView = ComponentCreation.CreateDataView(Env, data);
+ var pipe = new HashEstimator(Env, new[]{
+ new HashTransformer.ColumnInfo("A", "HashA", hashBits:4, invertHash:-1),
+ new HashTransformer.ColumnInfo("B", "HashB", hashBits:3, ordered:true),
+ new HashTransformer.ColumnInfo("C", "HashC", seed:42),
+ new HashTransformer.ColumnInfo("A", "HashD"),
+ });
+ var result = pipe.Fit(dataView).Transform(dataView);
+ var resultRoles = new RoleMappedData(result);
+ using (var ms = new MemoryStream())
+ {
+ TrainUtils.SaveModel(Env, Env.Start("saving"), ms, null, resultRoles);
+ // Rewind so the loader reads the model from the beginning of the stream.
+ ms.Position = 0;
+ var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms);
+ }
+ }
+ }
+}