Skip to content

Categorical estimator #899

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Sep 18, 2018
12 changes: 6 additions & 6 deletions src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Data.StaticPipe.Runtime;
using Microsoft.ML.Runtime;
Expand All @@ -16,11 +12,15 @@
using Microsoft.ML.Runtime.Model.Onnx;
using Microsoft.ML.Runtime.Model.Pfa;
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

[assembly: LoadableClass(KeyToVectorTransform.Summary, typeof(IDataTransform), typeof(KeyToVectorTransform), typeof(KeyToVectorTransform.Arguments), typeof(SignatureDataTransform),
"Key To Vector Transform", KeyToVectorTransform.UserName, "KeyToVector", "ToVector", DocName = "transform/KeyToVectorTransform.md")]

[assembly: LoadableClass(KeyToVectorTransform.Summary, typeof(IDataView), typeof(KeyToVectorTransform), null, typeof(SignatureLoadDataTransform),
[assembly: LoadableClass(KeyToVectorTransform.Summary, typeof(IDataTransform), typeof(KeyToVectorTransform), null, typeof(SignatureLoadDataTransform),
"Key To Vector Transform", KeyToVectorTransform.LoaderSignature)]

[assembly: LoadableClass(KeyToVectorTransform.Summary, typeof(KeyToVectorTransform), null, typeof(SignatureLoadModel),
Expand Down Expand Up @@ -733,7 +733,7 @@ public KeyToVectorEstimator(IHostEnvironment env, string name, string source = n
{
}

public KeyToVectorEstimator(IHostEnvironment env, KeyToVectorTransform transformer)
private KeyToVectorEstimator(IHostEnvironment env, KeyToVectorTransform transformer)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(KeyToVectorEstimator)), transformer)
{
}
Expand Down
29 changes: 22 additions & 7 deletions src/Microsoft.ML.Data/Transforms/TermEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,31 @@
using Microsoft.ML.Data.StaticPipe.Runtime;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;

namespace Microsoft.ML.Runtime.Data
{
public sealed class TermEstimator : IEstimator<TermTransform>
{
public static class Defaults
Copy link
Contributor Author

@Ivanidzo4ka Ivanidzo4ka Sep 13, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Defaults [](start = 28, length = 8)

Is this a right way to do? #Closed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good to me


In reply to: 217229486 [](ancestors = 217229486)

{
public const int MaxNumTerms = 1000000;
public const TermTransform.SortOrder Sort = TermTransform.SortOrder.Occurrence;
}

private readonly IHost _host;
private readonly TermTransform.ColumnInfo[] _columns;
public TermEstimator(IHostEnvironment env, string name, string source = null, int maxNumTerms = TermTransform.Defaults.MaxNumTerms, TermTransform.SortOrder sort = TermTransform.Defaults.Sort) :

/// <summary>
/// Convenience constructor for public facing API.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
/// <param name="maxNumTerms">Maximum number of terms to keep per column when auto-training.</param>
/// <param name="sort">How items should be ordered when vectorized. By default, they will be in the order encountered.
/// If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a').</param>
public TermEstimator(IHostEnvironment env, string name, string source = null, int maxNumTerms = Defaults.MaxNumTerms, TermTransform.SortOrder sort = Defaults.Sort) :
this(env, new TermTransform.ColumnInfo(name, source ?? name, maxNumTerms, sort))
{
}
Expand Down Expand Up @@ -47,7 +62,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
if (!col.IsKey || !col.Metadata.TryFindColumn(MetadataUtils.Kinds.KeyValues, out var kv) || kv.Kind != SchemaShape.Column.VectorKind.Vector)
{
kv = new SchemaShape.Column(MetadataUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
col.ItemType, col.IsKey);
colInfo.TextKeyValues ? TextType.Instance : col.ItemType, col.IsKey);
}
Contracts.AssertValue(kv);

Expand Down Expand Up @@ -90,7 +105,7 @@ public sealed class ToKeyFitResult<T>
// At the moment this is empty. Once PR #863 clears, we can change this class to hold the output
// key-values metadata.

internal ToKeyFitResult(TermTransform.TermMap map)
public ToKeyFitResult(TermTransform.TermMap map)
{
}
}
Expand All @@ -101,8 +116,8 @@ public static partial class TermStaticExtensions
// Raw generics would allow illegal possible inputs, e.g., Scalar<Bitmap>. So, this is a partial
// class, and all the public facing extension methods for each possible type are in a T4 generated result.

private const KeyValueOrder DefSort = (KeyValueOrder)TermTransform.Defaults.Sort;
private const int DefMax = TermTransform.Defaults.MaxNumTerms;
private const KeyValueOrder DefSort = (KeyValueOrder)TermEstimator.Defaults.Sort;
private const int DefMax = TermEstimator.Defaults.MaxNumTerms;

private struct Config
{
Expand Down Expand Up @@ -176,7 +191,7 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env, Pipelin
{
var infos = new TermTransform.ColumnInfo[toOutput.Length];
Action<TermTransform> onFit = null;
for (int i=0; i<toOutput.Length; ++i)
for (int i = 0; i < toOutput.Length; ++i)
{
var tcol = (ITermCol)toOutput[i];
infos[i] = new TermTransform.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]],
Expand Down
20 changes: 7 additions & 13 deletions src/Microsoft.ML.Data/Transforms/TermTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
typeof(TermTransform.Arguments), typeof(SignatureDataTransform),
TermTransform.UserName, "Term", "AutoLabel", "TermTransform", "AutoLabelTransform", DocName = "transform/TermTransform.md")]

[assembly: LoadableClass(TermTransform.Summary, typeof(IDataView), typeof(TermTransform), null, typeof(SignatureLoadDataTransform),
[assembly: LoadableClass(TermTransform.Summary, typeof(IDataTransform), typeof(TermTransform), null, typeof(SignatureLoadDataTransform),
TermTransform.UserName, TermTransform.LoaderSignature)]

[assembly: LoadableClass(TermTransform.Summary, typeof(TermTransform), null, typeof(SignatureLoadModel),
Expand Down Expand Up @@ -101,16 +101,10 @@ public enum SortOrder : byte
// other things, like case insensitive (where appropriate), culturally aware, etc.?
}

internal static class Defaults
{
public const int MaxNumTerms = 1000000;
public const SortOrder Sort = SortOrder.Occurrence;
}

public abstract class ArgumentsBase : TransformInputBase
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of terms to keep per column when auto-training", ShortName = "max", SortOrder = 5)]
public int MaxNumTerms = Defaults.MaxNumTerms;
public int MaxNumTerms = TermEstimator.Defaults.MaxNumTerms;

[Argument(ArgumentType.AtMostOnce, HelpText = "Comma separated list of terms", SortOrder = 105, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
public string Terms;
Expand All @@ -134,7 +128,7 @@ public abstract class ArgumentsBase : TransformInputBase
// REVIEW: Should we always sort? Opinions are mixed. See work item 7797429.
[Argument(ArgumentType.AtMostOnce, HelpText = "How items should be ordered when vectorized. By default, they will be in the order encountered. " +
"If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a').", SortOrder = 113)]
public SortOrder Sort = Defaults.Sort;
public SortOrder Sort = TermEstimator.Defaults.Sort;

// REVIEW: Should we do this here, or correct the various pieces of code here and in MRS etc. that
// assume key-values will be string? Once we correct these things perhaps we can see about removing it.
Expand Down Expand Up @@ -164,7 +158,7 @@ public ColInfo(string name, string source, ColumnType type)

public class ColumnInfo
{
public ColumnInfo(string input, string output, int maxNumTerms = Defaults.MaxNumTerms, SortOrder sort = Defaults.Sort, string[] term = null, bool textKeyValues = false)
public ColumnInfo(string input, string output, int maxNumTerms = TermEstimator.Defaults.MaxNumTerms, SortOrder sort = TermEstimator.Defaults.Sort, string[] term = null, bool textKeyValues = false)
{
Input = input;
Output = output;
Expand All @@ -181,7 +175,7 @@ public ColumnInfo(string input, string output, int maxNumTerms = Defaults.MaxNum
public readonly string[] Term;
public readonly bool TextKeyValues;

internal string Terms { get; set; }
protected internal string Terms { get; set; }
}

public const string Summary = "Converts input values (words, numbers, etc.) to index in a dictionary.";
Expand Down Expand Up @@ -406,7 +400,7 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISc
/// If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a').</param>
public static IDataView Create(IHostEnvironment env,
IDataView input, string name, string source = null,
int maxNumTerms = Defaults.MaxNumTerms, SortOrder sort = Defaults.Sort) =>
int maxNumTerms = TermEstimator.Defaults.MaxNumTerms, SortOrder sort = TermEstimator.Defaults.Sort) =>
new TermTransform(env, input, new[] { new ColumnInfo(source ?? name, name, maxNumTerms, sort) }).MakeDataTransform(input);

public static IDataTransform Create(IHostEnvironment env, ArgumentsBase args, ColumnBase[] column, IDataView input)
Expand Down Expand Up @@ -710,7 +704,7 @@ public override void Save(ModelSaveContext ctx)
});
}

internal TermMap GetTermMap(int iinfo)
public TermMap GetTermMap(int iinfo)
{
Contracts.Assert(0 <= iinfo && iinfo < _unboundMaps.Length);
return _unboundMaps[iinfo];
Expand Down
12 changes: 6 additions & 6 deletions src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ private static BoundTermMap Bind(IHostEnvironment env, ISchema schema, TermMap u
/// These are the immutable and serializable analogs to the <see cref="Builder"/> used in
/// training.
/// </summary>
internal abstract class TermMap
public abstract class TermMap
{
/// <summary>
/// The item type of the input type, that is, either the input type or,
Expand Down Expand Up @@ -501,9 +501,9 @@ protected TermMap(PrimitiveType type, int count)
OutputType = new KeyType(DataKind.U4, 0, Count == 0 ? 1 : Count);
}

public abstract void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory);
internal abstract void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory);

public static TermMap Load(ModelLoadContext ctx, IHostEnvironment ectx, CodecFactory codecFactory)
internal static TermMap Load(ModelLoadContext ctx, IHostEnvironment ectx, CodecFactory codecFactory)
{
// *** Binary format ***
// byte: map type code
Expand Down Expand Up @@ -610,7 +610,7 @@ public static TextImpl Create(ModelLoadContext ctx, IExceptionContext ectx)
return new TextImpl(pool);
}

public override void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory)
internal override void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory)
{
// *** Binary format ***
// byte: map type code, in this case 'Text' (0)
Expand Down Expand Up @@ -685,7 +685,7 @@ public HashArrayImpl(PrimitiveType itemType, HashArray<T> values)
_values = values;
}

public override void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory)
internal override void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory)
{
// *** Binary format ***
// byte: map type code, in this case 'Codec'
Expand Down Expand Up @@ -757,7 +757,7 @@ public override void WriteTextTerms(TextWriter writer)
}
}

internal abstract class TermMap<T> : TermMap
public abstract class TermMap<T> : TermMap
{
protected TermMap(PrimitiveType type, int count)
: base(type, count)
Expand Down
83 changes: 72 additions & 11 deletions src/Microsoft.ML.Transforms/CategoricalHashTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,14 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Float = System.Single;

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Internal.Internallearn;

[assembly: LoadableClass(CategoricalHashTransform.Summary, typeof(IDataTransform), typeof(CategoricalHashTransform), typeof(CategoricalHashTransform.Arguments), typeof(SignatureDataTransform),
CategoricalHashTransform.UserName, "CategoricalHashTransform", "CatHashTransform", "CategoricalHash", "CatHash")]
Expand Down Expand Up @@ -62,14 +59,11 @@ protected override bool TryParse(string str)

// We accept N:B:S where N is the new column name, B is the number of bits,
// and S is source column names.
string extra;
if (!base.TryParse(str, out extra))
if (!TryParse(str, out string extra))
return false;
if (extra == null)
return true;

int bits;
if (!int.TryParse(extra, out bits))
if (!int.TryParse(extra, out int bits))
return false;
HashBits = bits;
return true;
Expand Down Expand Up @@ -201,14 +195,81 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
};
}

return CategoricalTransform.CreateTransformCore(
return CreateTransformCore(
args.OutputKind, args.Column,
args.Column.Select(col => col.OutputKind).ToList(),
new HashTransform(h, hashArgs, input),
h,
env,
args);
}
}

private static IDataTransform CreateTransformCore(CategoricalTransform.OutputKind argsOutputKind, OneToOneColumn[] columns,
List<CategoricalTransform.OutputKind?> columnOutputKinds, IDataTransform input, IHost h, Arguments catHashArgs = null)
{
Contracts.CheckValue(columns, nameof(columns));
Contracts.CheckValue(columnOutputKinds, nameof(columnOutputKinds));
Contracts.CheckParam(columns.Length == columnOutputKinds.Count, nameof(columns));

using (var ch = h.Start("Create Transform Core"))
{
// Create the KeyToVectorTransform, if needed.
var cols = new List<KeyToVectorTransform.Column>();
bool binaryEncoding = argsOutputKind == CategoricalTransform.OutputKind.Bin;
for (int i = 0; i < columns.Length; i++)
{
var column = columns[i];
if (!column.TrySanitize())
throw h.ExceptUserArg(nameof(Column.Name));

bool? bag;
CategoricalTransform.OutputKind kind = columnOutputKinds[i] ?? argsOutputKind;
switch (kind)
{
default:
throw ch.ExceptUserArg(nameof(Column.OutputKind));
case CategoricalTransform.OutputKind.Key:
continue;
case CategoricalTransform.OutputKind.Bin:
binaryEncoding = true;
bag = false;
break;
case CategoricalTransform.OutputKind.Ind:
bag = false;
break;
case CategoricalTransform.OutputKind.Bag:
bag = true;
break;
}
var col = new KeyToVectorTransform.Column();
Copy link
Contributor

@Zruty0 Zruty0 Sep 13, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Column [](start = 55, length = 6)

object initializer? #ByDesign

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's gonna be dead after hash transform become estimator.


In reply to: 217236161 [](ancestors = 217236161)

col.Name = column.Name;
col.Source = column.Name;
col.Bag = bag;
cols.Add(col);
}

if (cols.Count == 0)
return input;

IDataTransform transform;
if (binaryEncoding)
{
if ((catHashArgs?.InvertHash ?? 0) != 0)
ch.Warning("Invert hashing is being used with binary encoding.");

var keyToBinaryVecCols = cols.Select(x => new KeyToBinaryVectorTransform.ColumnInfo(x.Source, x.Name)).ToArray();
transform = KeyToBinaryVectorTransform.Create(h, input, keyToBinaryVecCols);
}
else
{
var keyToVecCols = cols.Select(x => new KeyToVectorTransform.ColumnInfo(x.Source, x.Name, x.Bag ?? argsOutputKind == CategoricalTransform.OutputKind.Bag)).ToArray();

transform = KeyToVectorTransform.Create(h, input, keyToVecCols);
}

ch.Done();
return transform;
}
}
}
}
Loading