Skip to content

KeyToVector estimators #858

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Sep 12, 2018
1,217 changes: 819 additions & 398 deletions src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs

Large diffs are not rendered by default.

62 changes: 15 additions & 47 deletions src/Microsoft.ML.Data/Transforms/TermTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -241,18 +241,25 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum
return columns.Select(x => (x.Input, x.Output)).ToArray();
}

private ColInfo[] CreateInfos(ISchema schema)
internal string TestIsKnownDataKind(ColumnType type)
{
Host.AssertValue(schema);
if (type.ItemType.RawKind != default && (type.IsVector || type.IsPrimitive))
return null;
return "standard type or a vector of standard type";
}

private ColInfo[] CreateInfos(ISchema inputSchema)
{
Host.AssertValue(inputSchema);
var infos = new ColInfo[ColumnPairs.Length];
for (int i = 0; i < ColumnPairs.Length; i++)
{
if (!schema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc))
throw Host.ExceptUserArg(nameof(ColumnPairs), "Source column '{0}' not found", ColumnPairs[i].input);
var type = schema.GetColumnType(colSrc);
if (!inputSchema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].input);
var type = inputSchema.GetColumnType(colSrc);
string reason = TestIsKnownDataKind(type);
if (reason != null)
throw Host.ExceptUserArg(nameof(ColumnPairs), InvalidTypeErrorFormat, ColumnPairs[i].input, type, reason);
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].input, reason, type.ToString());
infos[i] = new ColInfo(ColumnPairs[i].output, ColumnPairs[i].input, type);
}
return infos;
Expand All @@ -271,7 +278,7 @@ private TermTransform(IHostEnvironment env, IDataView input,
{
using (var ch = Host.Start("Training"))
{
var infos = CreateInfos(Host, ColumnPairs, input.Schema, TestIsKnownDataKind);
var infos = CreateInfos(input.Schema);
_unboundMaps = Train(Host, ch, infos, file, termsColumn, loaderFactory, columns, input);
_textMetadata = new bool[_unboundMaps.Length];
for (int iinfo = 0; iinfo < columns.Length; ++iinfo)
Expand Down Expand Up @@ -400,32 +407,6 @@ public static IDataView Create(IHostEnvironment env,
int maxNumTerms = Defaults.MaxNumTerms, SortOrder sort = Defaults.Sort) =>
new TermTransform(env, input, new[] { new ColumnInfo(source ?? name, name, maxNumTerms, sort) }).MakeDataTransform(input);

//REVIEW: This and static method below need to go to base class as it get created.
private const string InvalidTypeErrorFormat = "Source column '{0}' has invalid type ('{1}'): {2}.";

private static ColInfo[] CreateInfos(IHostEnvironment env, (string source, string name)[] columns, ISchema schema, Func<ColumnType, string> testType)
{
env.CheckUserArg(Utils.Size(columns) > 0, nameof(columns));
env.AssertValue(schema);
env.AssertValueOrNull(testType);

var infos = new ColInfo[columns.Length];
for (int i = 0; i < columns.Length; i++)
{
if (!schema.TryGetColumnIndex(columns[i].source, out int colSrc))
throw env.ExceptUserArg(nameof(columns), "Source column '{0}' not found", columns[i].source);
var type = schema.GetColumnType(colSrc);
if (testType != null)
{
string reason = testType(type);
if (reason != null)
throw env.ExceptUserArg(nameof(columns), InvalidTypeErrorFormat, columns[i].source, type, reason);
}
infos[i] = new ColInfo(columns[i].name, columns[i].source, type);
}
return infos;
}

public static IDataTransform Create(IHostEnvironment env, ArgumentsBase args, ColumnBase[] column, IDataView input)
{
return Create(env, new Arguments()
Expand All @@ -452,13 +433,6 @@ public static IDataTransform Create(IHostEnvironment env, ArgumentsBase args, Co
}, input);
}

internal static string TestIsKnownDataKind(ColumnType type)
{
if (type.ItemType.RawKind != default && (type.IsVector || type.IsPrimitive))
return null;
return "Expected standard type or a vector of standard type";
}

/// <summary>
/// Utility method to create the file-based <see cref="TermMap"/>.
/// </summary>
Expand Down Expand Up @@ -701,7 +675,7 @@ public override void Save(ModelSaveContext ctx)
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());

base.SaveColumns(ctx);
SaveColumns(ctx);

Host.Assert(_unboundMaps.Length == _textMetadata.Length);
Host.Assert(_textMetadata.Length == ColumnPairs.Length);
Expand Down Expand Up @@ -743,12 +717,6 @@ internal TermMap GetTermMap(int iinfo)
protected override IRowMapper MakeRowMapper(ISchema schema)
=> new Mapper(this, schema);

protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol)
{
if ((inputSchema.GetColumnType(srcCol).ItemType.RawKind == default))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, "image", inputSchema.GetColumnType(srcCol).ToString());
}

private sealed class Mapper : MapperBase, ISaveAsOnnx, ISaveAsPfa
{
private readonly ColumnType[] _types;
Expand Down
15 changes: 5 additions & 10 deletions src/Microsoft.ML.Transforms/CategoricalTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ public static IDataTransform CreateTransformCore(
using (var ch = h.Start("Create Transform Core"))
{
// Create the KeyToVectorTransform, if needed.
List<KeyToVectorTransform.Column> cols = new List<KeyToVectorTransform.Column>();
var cols = new List<KeyToVectorTransform.Column>();
bool binaryEncoding = argsOutputKind == OutputKind.Bin;
for (int i = 0; i < columns.Length; i++)
{
Expand Down Expand Up @@ -220,19 +220,14 @@ public static IDataTransform CreateTransformCore(
if ((catHashArgs?.InvertHash ?? 0) != 0)
ch.Warning("Invert hashing is being used with binary encoding.");

var keyToBinaryArgs = new KeyToBinaryVectorTransform.Arguments();
keyToBinaryArgs.Column = cols.ToArray();
transform = new KeyToBinaryVectorTransform(h, keyToBinaryArgs, input);
var keyToBinaryVecCols = cols.Select(x => new KeyToBinaryVectorTransform.ColumnInfo(x.Source, x.Name)).ToArray();
transform = KeyToBinaryVectorTransform.Create(h, input, keyToBinaryVecCols);
}
else
{
var keyToVecArgs = new KeyToVectorTransform.Arguments
{
Bag = argsOutputKind == OutputKind.Bag,
Column = cols.ToArray()
};
var keyToVecCols = cols.Select(x => new KeyToVectorTransform.ColumnInfo(x.Source, x.Name, x.Bag ?? argsOutputKind == OutputKind.Bag)).ToArray();

transform = new KeyToVectorTransform(h, keyToVecArgs, input);
transform = KeyToVectorTransform.Create(h, input, keyToVecCols);
}

ch.Done();
Expand Down
Loading