Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,014 changes: 591 additions & 423 deletions src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs

Large diffs are not rendered by default.

43 changes: 10 additions & 33 deletions src/Microsoft.ML.Data/Transforms/TermTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,13 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum
return columns.Select(x => (x.Input, x.Output)).ToArray();
}

internal static string TestIsKnownDataKind(ColumnType type)
Copy link
Contributor Author

@Ivanidzo4ka Ivanidzo4ka Sep 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

static [](start = 17, length = 6)

don't have to be static #Resolved

{
if (type.ItemType.RawKind != default && (type.IsVector || type.IsPrimitive))
return null;
return "standard type or a vector of standard type";
}

private ColInfo[] CreateInfos(ISchema schema)
{
Host.AssertValue(schema);
Expand All @@ -252,7 +259,7 @@ private ColInfo[] CreateInfos(ISchema schema)
var type = schema.GetColumnType(colSrc);
string reason = TestIsKnownDataKind(type);
if (reason != null)
throw Host.ExceptUserArg(nameof(ColumnPairs), InvalidTypeErrorFormat, ColumnPairs[i].input, type, reason);
throw Host.ExceptSchemaMismatch(nameof(ColumnPairs), "input", ColumnPairs[i].input, reason, type.ToString());
infos[i] = new ColInfo(ColumnPairs[i].output, ColumnPairs[i].input, type);
}
return infos;
Expand All @@ -271,7 +278,7 @@ private TermTransform(IHostEnvironment env, IDataView input,
{
using (var ch = Host.Start("Training"))
{
var infos = CreateInfos(Host, ColumnPairs, input.Schema, TestIsKnownDataKind);
var infos = CreateInfos(input.Schema);
_unboundMaps = Train(Host, ch, infos, file, termsColumn, loaderFactory, columns, input);
_textMetadata = new bool[_unboundMaps.Length];
for (int iinfo = 0; iinfo < columns.Length; ++iinfo)
Expand Down Expand Up @@ -403,29 +410,6 @@ public static IDataView Create(IHostEnvironment env,
//REVIEW: This and static method below need to go to base class as it get created.
private const string InvalidTypeErrorFormat = "Source column '{0}' has invalid type ('{1}'): {2}.";

private static ColInfo[] CreateInfos(IHostEnvironment env, (string source, string name)[] columns, ISchema schema, Func<ColumnType, string> testType)
{
env.CheckUserArg(Utils.Size(columns) > 0, nameof(columns));
env.AssertValue(schema);
env.AssertValueOrNull(testType);

var infos = new ColInfo[columns.Length];
for (int i = 0; i < columns.Length; i++)
{
if (!schema.TryGetColumnIndex(columns[i].source, out int colSrc))
throw env.ExceptUserArg(nameof(columns), "Source column '{0}' not found", columns[i].source);
var type = schema.GetColumnType(colSrc);
if (testType != null)
{
string reason = testType(type);
if (reason != null)
throw env.ExceptUserArg(nameof(columns), InvalidTypeErrorFormat, columns[i].source, type, reason);
}
infos[i] = new ColInfo(columns[i].name, columns[i].source, type);
}
return infos;
}

public static IDataTransform Create(IHostEnvironment env, ArgumentsBase args, ColumnBase[] column, IDataView input)
{
return Create(env, new Arguments()
Expand All @@ -452,13 +436,6 @@ public static IDataTransform Create(IHostEnvironment env, ArgumentsBase args, Co
}, input);
}

internal static string TestIsKnownDataKind(ColumnType type)
{
if (type.ItemType.RawKind != default && (type.IsVector || type.IsPrimitive))
return null;
return "Expected standard type or a vector of standard type";
}

/// <summary>
/// Utility method to create the file-based <see cref="TermMap"/>.
/// </summary>
Expand Down Expand Up @@ -701,7 +678,7 @@ public override void Save(ModelSaveContext ctx)
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());

base.SaveColumns(ctx);
SaveColumns(ctx);

Host.Assert(_unboundMaps.Length == _textMetadata.Length);
Host.Assert(_textMetadata.Length == ColumnPairs.Length);
Expand Down
2 changes: 0 additions & 2 deletions src/Microsoft.ML.Data/Transforms/TermTransformImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1051,7 +1051,6 @@ public override void AddMetadata(ColumnMetadataInfo colMetaInfo)
MetadataUtils.MetadataGetter<VBuffer<DvText>> getter =
(int iinfo, ref VBuffer<DvText> dst) =>
{
_host.Assert(iinfo == _iinfo);
// No buffer sharing convenient here.
VBuffer<T> dstT = default(VBuffer<T>);
TypedMap.GetTerms(ref dstT);
Expand All @@ -1066,7 +1065,6 @@ public override void AddMetadata(ColumnMetadataInfo colMetaInfo)
MetadataUtils.MetadataGetter<VBuffer<T>> getter =
(int iinfo, ref VBuffer<T> dst) =>
{
_host.Assert(iinfo == _iinfo);
TypedMap.GetTerms(ref dst);
};
var columnType = new VectorType(TypedMap.ItemType, TypedMap.OutputType.KeyCount);
Expand Down
15 changes: 5 additions & 10 deletions src/Microsoft.ML.Transforms/CategoricalTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ public static IDataTransform CreateTransformCore(
using (var ch = h.Start("Create Transform Core"))
{
// Create the KeyToVectorTransform, if needed.
List<KeyToVectorTransform.Column> cols = new List<KeyToVectorTransform.Column>();
var cols = new List<KeyToVectorTransform.Column>();
bool binaryEncoding = argsOutputKind == OutputKind.Bin;
for (int i = 0; i < columns.Length; i++)
{
Expand Down Expand Up @@ -220,19 +220,14 @@ public static IDataTransform CreateTransformCore(
if ((catHashArgs?.InvertHash ?? 0) != 0)
ch.Warning("Invert hashing is being used with binary encoding.");

var keyToBinaryArgs = new KeyToBinaryVectorTransform.Arguments();
keyToBinaryArgs.Column = cols.ToArray();
transform = new KeyToBinaryVectorTransform(h, keyToBinaryArgs, input);
var keyToBinaryVecCols = cols.Select(x => new KeyToBinaryVectorTransform.ColumnInfo(x.Source, x.Name)).ToArray();
transform = KeyToBinaryVectorTransform.Create(h, input, keyToBinaryVecCols);
}
else
{
var keyToVecArgs = new KeyToVectorTransform.Arguments
{
Bag = argsOutputKind == OutputKind.Bag,
Column = cols.ToArray()
};
var keyToVecCols = cols.Select(x => new KeyToVectorTransform.ColumnInfo(x.Source, x.Name, x.Bag ?? argsOutputKind == OutputKind.Bag)).ToArray();

transform = new KeyToVectorTransform(h, keyToVecArgs, input);
transform = KeyToVectorTransform.Create(h, input, keyToVecCols);
}

ch.Done();
Expand Down
Loading