No loader on APIs for ValueToKey/OneHotEncoding #2245

Merged 3 commits on Jan 27, 2019.
10 changes: 5 additions & 5 deletions src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs
@@ -199,13 +199,13 @@ public ColumnSelectingTransformer(IHostEnvironment env, string[] keepColumns, st
_host.CheckValueOrNull(keepColumns);
_host.CheckValueOrNull(dropColumns);

bool keepValid = keepColumns != null && keepColumns.Count() > 0;
bool dropValid = dropColumns != null && dropColumns.Count() > 0;
bool keepValid = Utils.Size(keepColumns) > 0;
bool dropValid = Utils.Size(dropColumns) > 0;

// Check that both are not valid
_host.Check(!(keepValid && dropValid), "Both keepColumns and dropColumns are set, only one can be specified.");
_host.Check(!(keepValid && dropValid), "Both " + nameof(keepColumns) + " and " + nameof(dropColumns) + " are set. Exactly one can be specified.");
// Check that both are invalid
_host.Check(!(!keepValid && !dropValid), "Neither keepColumns or dropColumns is set, one must be specified.");
_host.Check(!(!keepValid && !dropValid), "Neither " + nameof(keepColumns) + " nor " + nameof(dropColumns) + " is set. Exactly one must be specified.");

_selectedColumns = (keepValid) ? keepColumns : dropColumns;
KeepColumns = keepValid;
@@ -558,7 +558,7 @@ private static int[] BuildOutputToInputMap(IEnumerable<string> selectedColumns,
// given an input of ABC and dropping column B will result in AC.
// In drop mode, we drop all columns with the specified names and keep all the rest,
// ignoring the keepHidden argument.
for(int colIdx = 0; colIdx < inputSchema.Count; colIdx++)
for (int colIdx = 0; colIdx < inputSchema.Count; colIdx++)
{
if (selectedColumns.Contains(inputSchema[colIdx].Name))
continue;
15 changes: 6 additions & 9 deletions src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs
@@ -112,19 +112,16 @@ public static ValueToKeyMappingEstimator MapValueToKey(this TransformsCatalog.Co
=> new ValueToKeyMappingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, maxNumTerms, sort);

/// <summary>
/// Converts value types into <see cref="KeyType"/> loading the keys to use from <paramref name="file"/>.
/// Converts value types into <see cref="KeyType"/>, optionally loading the keys to use from <paramref name="keyData"/>.
/// </summary>
/// <param name="catalog">The categorical transform's catalog.</param>
/// <param name="columns">The data columns to map to keys.</param>
/// <param name="file">The path of the file containing the terms.</param>
/// <param name="termsColumn"></param>
/// <param name="loaderFactory"></param>
/// <param name="keyData">The data view containing the terms. If specified, this should be a single column data
/// view, and the key-values will be taken from taht column. If unspecified, the key-values will be determined
/// from the input data upon fitting.</param>
public static ValueToKeyMappingEstimator MapValueToKey(this TransformsCatalog.ConversionTransforms catalog,
@Ivanidzo4ka (Contributor) commented on Jan 25, 2019, referring to ValueToKeyMappingEstimator:

I've tried to look it up, but GitHub doesn't let you smoothly explore file changes if a file gets renamed.
Why do we even have this public constructor?
We already have ValueMapping, which accepts an IEnumerable for the keys, and ZeeshanA is adding another constructor which accepts an IDataView.
Considering all the restrictions required on the IDataView, don't you think it would be easier to use the ValueMapping transform?

It's just a question, not a PR blocker.

The PR author (Contributor) replied:

Hi Ivan, I don't understand this question; could you clarify?

Are you asking why there is a public constructor on the estimators? I did not change the public/non-public nature of this constructor of the estimator. If you are pointing out that the estimator constructors should be non-public, since we prefer things to be created through MLContext, I agree, and I believe there is already an issue tracking this, #2100. However, the issue I am addressing here is orthogonal to that. I agree it should be done; I see no one is assigned to it.

Also, I am not sure what you are talking about with respect to another constructor; certainly I see no other constructor here. I am trying to address the issue that whoever wrote the original estimator API decided, for whatever reason, that using component factories and IDataLoader was a great idea, as discussed in the issue. The most obvious way to maintain the same sort of functionality (loading something from a file) is to provide the constructor I have provided. Are you disagreeing with that?

Also, I did not rename the file?

But maybe I am misunderstanding your question. Perhaps you could clarify.

ValueToKeyMappingTransformer.ColumnInfo[] columns,
string file = null,
string termsColumn = null,
IComponentFactory<IMultiStreamSource, IDataLoader> loaderFactory = null)
=> new ValueToKeyMappingEstimator(CatalogUtils.GetEnvironment(catalog), columns, file, termsColumn, loaderFactory);
ValueToKeyMappingTransformer.ColumnInfo[] columns, IDataView keyData = null)
=> new ValueToKeyMappingEstimator(CatalogUtils.GetEnvironment(catalog), columns, keyData);

/// <summary>
/// Maps specified keys to specified values
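
For illustration, the revised catalog overload is exercised roughly as follows. This mirrors the ValueToKeyFromSideData test added later in this PR; the input enumerable "data" and the column names "A" and "CatA" are taken from that test and are otherwise arbitrary placeholders.

var mlContext = new MLContext();
// dataView is any IDataView with a text column "A" (see TestStringClass in the tests below).
var dataView = mlContext.Data.ReadFromEnumerable(data);

// The terms now come from a single-column side data view rather than a file plus loader factory.
var sideDataBuilder = new ArrayDataViewBuilder(mlContext);
sideDataBuilder.AddColumn("Hello", "hello", "my", "friend");
var sideData = sideDataBuilder.GetDataView();

var ci = new ValueToKeyMappingTransformer.ColumnInfo("A", "CatA");
var pipe = mlContext.Transforms.Conversion.MapValueToKey(new[] { ci }, sideData);
var output = pipe.Fit(dataView).Transform(dataView);
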
27 changes: 15 additions & 12 deletions src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs
@@ -9,7 +9,7 @@
namespace Microsoft.ML.Transforms.Conversions
{
/// <include file='doc.xml' path='doc/members/member[@name="ValueToKeyMappingEstimator"]/*' />
public sealed class ValueToKeyMappingEstimator: IEstimator<ValueToKeyMappingTransformer>
public sealed class ValueToKeyMappingEstimator : IEstimator<ValueToKeyMappingTransformer>
{
public static class Defaults
{
@@ -19,9 +19,7 @@ public static class Defaults

private readonly IHost _host;
private readonly ValueToKeyMappingTransformer.ColumnInfo[] _columns;
private readonly string _file;
private readonly string _termsColumn;
private readonly IComponentFactory<IMultiStreamSource, IDataLoader> _loaderFactory;
private readonly IDataView _keyData;

/// <summary>
/// Initializes a new instance of <see cref="ValueToKeyMappingEstimator"/>.
@@ -33,23 +31,28 @@ public static class Defaults
/// <param name="sort">How items should be ordered when vectorized. If <see cref="ValueToKeyMappingTransformer.SortOrder.Occurrence"/> choosen they will be in the order encountered.
/// If <see cref="ValueToKeyMappingTransformer.SortOrder.Value"/>, items are sorted according to their default comparison, for example, text sorting will be case sensitive (for example, 'A' then 'Z' then 'a').</param>
public ValueToKeyMappingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, int maxNumTerms = Defaults.MaxNumTerms, ValueToKeyMappingTransformer.SortOrder sort = Defaults.Sort) :
this(env, new [] { new ValueToKeyMappingTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, maxNumTerms, sort) })
this(env, new[] { new ValueToKeyMappingTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, maxNumTerms, sort) })
{
}

public ValueToKeyMappingEstimator(IHostEnvironment env, ValueToKeyMappingTransformer.ColumnInfo[] columns,
string file = null, string termsColumn = null,
IComponentFactory<IMultiStreamSource, IDataLoader> loaderFactory = null)
public ValueToKeyMappingEstimator(IHostEnvironment env, ValueToKeyMappingTransformer.ColumnInfo[] columns, IDataView keyData = null)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(ValueToKeyMappingEstimator));
_host.CheckNonEmpty(columns, nameof(columns));
_host.CheckValueOrNull(keyData);
if (keyData != null && keyData.Schema.Count != 1)
{
throw _host.ExceptParam(nameof(keyData), "If specified, this data view should contain only a single column " +
$"containing the terms to map, but this had {keyData.Schema.Count} columns.");

}

_columns = columns;
_file = file;
_termsColumn = termsColumn;
_loaderFactory = loaderFactory;
_keyData = keyData;
}

public ValueToKeyMappingTransformer Fit(IDataView input) => new ValueToKeyMappingTransformer(_host, input, _columns, _file, _termsColumn, _loaderFactory);
public ValueToKeyMappingTransformer Fit(IDataView input) => new ValueToKeyMappingTransformer(_host, input, _columns, _keyData, false);

public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
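
To sketch the single-column requirement enforced by the constructor above: if the terms live in a data view with more than one column, narrow it first, for example with the ColumnSelectingTransformer that this change also touches. Here env, columns, and multiColumnTermData are placeholders for the host environment, the ColumnInfo array, and the wider side data.

// keyData must contain exactly one column; otherwise the constructor throws.
IDataView narrowed = new ColumnSelectingTransformer(env, new[] { "Term" }, null)
    .Transform(multiColumnTermData);
var estimator = new ValueToKeyMappingEstimator(env, columns, narrowed);
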
115 changes: 74 additions & 41 deletions src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs
@@ -290,24 +290,20 @@ private ColInfo[] CreateInfos(Schema inputSchema)

internal ValueToKeyMappingTransformer(IHostEnvironment env, IDataView input,
params ColumnInfo[] columns) :
this(env, input, columns, null, null, null)
this(env, input, columns, null, false)
{ }

internal ValueToKeyMappingTransformer(IHostEnvironment env, IDataView input,
ColumnInfo[] columns,
string file = null, string termsColumn = null,
IComponentFactory<IMultiStreamSource, IDataLoader> loaderFactory = null)
ColumnInfo[] columns, IDataView keyData, bool autoConvert)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns))
{
using (var ch = Host.Start("Training"))
{
var infos = CreateInfos(input.Schema);
_unboundMaps = Train(Host, ch, infos, file, termsColumn, loaderFactory, columns, input);
_unboundMaps = Train(Host, ch, infos, keyData, columns, input, autoConvert);
_textMetadata = new bool[_unboundMaps.Length];
for (int iinfo = 0; iinfo < columns.Length; ++iinfo)
{
_textMetadata[iinfo] = columns[iinfo].TextKeyValues;
}
ch.Assert(_unboundMaps.Length == columns.Length);
}
}
@@ -348,8 +344,9 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat
item.TextKeyValues ?? args.TextKeyValues);
cols[i].Terms = item.Terms ?? args.Terms;
};
var keyData = GetKeyDataViewOrNull(env, ch, args.DataFile, args.TermsColumn, args.Loader, out bool autoLoaded);
return new ValueToKeyMappingTransformer(env, input, cols, keyData, autoLoaded).MakeDataTransform(input);
}
return new ValueToKeyMappingTransformer(env, input, cols, args.DataFile, args.TermsColumn, args.Loader).MakeDataTransform(input);
}

// Factory method for SignatureLoadModel.
@@ -416,29 +413,44 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Sch
=> Create(env, ctx).MakeRowMapper(inputSchema);

/// <summary>
/// Utility method to create the file-based <see cref="TermMap"/>.
/// Returns a single-column <see cref="IDataView"/>, based on values from <see cref="Arguments"/>,
/// in the case where <see cref="ArgumentsBase.DataFile"/> is set. If that is not set, this will
/// return <see langword="null"/>.
/// </summary>
private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, string file, string termsColumn,
IComponentFactory<IMultiStreamSource, IDataLoader> loaderFactory, Builder bldr)
/// <param name="env">The host environment.</param>
/// <param name="ch">The host channel to use to mark exceptions and log messages.</param>
/// <param name="file">The name of the file. Must be specified if this method is called.</param>
/// <param name="termsColumn">The single column to select out of this transform. If not specified,
/// this method will attempt to guess.</param>
/// <param name="loaderFactory">The loader creator. If <see langword="null"/> we will attempt to determine
/// this </param>
/// <param name="autoConvert">Whether we should try to convert to the desired type by ourselves when doing
/// the term map. This will not be true in the case that the loader was adequately specified automatically.</param>
/// <returns>The single-column data containing the term data from the file.</returns>
[BestFriend]
internal static IDataView GetKeyDataViewOrNull(IHostEnvironment env, IChannel ch,
string file, string termsColumn, IComponentFactory<IMultiStreamSource, IDataLoader> loaderFactory,
out bool autoConvert)
{
Contracts.AssertValue(ch);
ch.AssertValue(env);
ch.Assert(!string.IsNullOrWhiteSpace(file));
ch.AssertValue(bldr);
ch.AssertValueOrNull(file);
ch.AssertValueOrNull(termsColumn);
ch.AssertValueOrNull(loaderFactory);

// If the user manually specifies a loader, or this is already a pre-processed binary
// file, then we assume the user knows what they're doing when they are so explicit,
// and do not attempt to convert to the desired type ourselves.
autoConvert = false;
if (string.IsNullOrWhiteSpace(file))
return null;

// First column using the file.
string src = termsColumn;
IMultiStreamSource fileSource = new MultiFileSource(file);

// If the user manually specifies a loader, or this is already a pre-processed binary
// file, then we assume the user knows what they're doing and do not attempt to convert
// to the desired type ourselves.
bool autoConvert = false;
IDataView termData;
IDataView keyData;
if (loaderFactory != null)
{
termData = loaderFactory.CreateComponent(env, fileSource);
}
keyData = loaderFactory.CreateComponent(env, fileSource);
else
{
// Determine the default loader from the extension.
@@ -451,44 +463,65 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, stri
ch.CheckUserArg(!string.IsNullOrWhiteSpace(src), nameof(termsColumn),
"Must be specified");
if (isBinary)
termData = new BinaryLoader(env, new BinaryLoader.Arguments(), fileSource);
keyData = new BinaryLoader(env, new BinaryLoader.Arguments(), fileSource);
else
{
ch.Assert(isTranspose);
termData = new TransposeLoader(env, new TransposeLoader.Arguments(), fileSource);
keyData = new TransposeLoader(env, new TransposeLoader.Arguments(), fileSource);
}
}
else
{
if (!string.IsNullOrWhiteSpace(src))
{
ch.Warning(
"{0} should not be specified when default loader is TextLoader. Ignoring {0}={1}",
"{0} should not be specified when default loader is " + nameof(TextLoader) + ". Ignoring {0}={1}",
nameof(Arguments.TermsColumn), src);
}
termData = new TextLoader(env,
keyData = new TextLoader(env,
columns: new[] { new TextLoader.Column("Term", DataKind.TX, 0) },
dataSample: fileSource)
.Read(fileSource);
src = "Term";
// In this case they are relying on heuristics, so automatic conversion is most appropriate.
autoConvert = true;
}
}
ch.AssertNonEmpty(src);

int colSrc;
if (!termData.Schema.TryGetColumnIndex(src, out colSrc))
if (keyData.Schema.GetColumnOrNull(src) == null)
throw ch.ExceptUserArg(nameof(termsColumn), "Unknown column '{0}'", src);
var typeSrc = termData.Schema[colSrc].Type;
// Now, remove everything but that one column.
var selectTransformer = new ColumnSelectingTransformer(env, new string[] { src }, null);
keyData = selectTransformer.Transform(keyData);
ch.Assert(keyData.Schema.Count == 1);
return keyData;
}

/// <summary>
/// Utility method to create the file-based <see cref="TermMap"/>.
/// </summary>
private static TermMap CreateTermMapFromData(IHostEnvironment env, IChannel ch, IDataView keyData, bool autoConvert, Builder bldr)
{
Contracts.AssertValue(ch);
ch.AssertValue(env);
ch.AssertValue(keyData);
ch.AssertValue(bldr);
if (keyData.Schema.Count != 1)
{
throw ch.ExceptParam(nameof(keyData), $"Input data containing terms should contain exactly one column, but " +
$"had {keyData.Schema.Count} instead. Consider using {nameof(ColumnSelectingEstimator)} on that data first.");
}

var typeSrc = keyData.Schema[0].Type;
if (!autoConvert && !typeSrc.Equals(bldr.ItemType))
throw ch.ExceptUserArg(nameof(termsColumn), "Must be of type '{0}' but was '{1}'", bldr.ItemType, typeSrc);
throw ch.ExceptUserArg(nameof(keyData), "Input data's column must be of type '{0}' but was '{1}'", bldr.ItemType, typeSrc);

using (var cursor = termData.GetRowCursor(termData.Schema[colSrc]))
using (var pch = env.StartProgressChannel("Building term dictionary from file"))
using (var cursor = keyData.GetRowCursor(keyData.Schema[0]))
using (var pch = env.StartProgressChannel("Building dictionary from term data"))
{
var header = new ProgressHeader(new[] { "Total Terms" }, new[] { "examples" });
var trainer = Trainer.Create(cursor, colSrc, autoConvert, int.MaxValue, bldr);
double rowCount = termData.GetRowCount() ?? double.NaN;
var trainer = Trainer.Create(cursor, 0, autoConvert, int.MaxValue, bldr);
double rowCount = keyData.GetRowCount() ?? double.NaN;
long rowCur = 0;
pch.SetHeader(header,
e =>
@@ -501,7 +534,7 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, stri
while (cursor.MoveNext() && trainer.ProcessRow())
rowCur++;
if (trainer.Count == 0)
ch.Warning("Term map loaded from file resulted in an empty map.");
ch.Warning("Map from the term data resulted in an empty map.");
pch.Checkpoint(trainer.Count, rowCur);
return trainer.Finish();
}
@@ -511,12 +544,12 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, stri
/// This builds the <see cref="TermMap"/> instances per column.
/// </summary>
private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] infos,
string file, string termsColumn,
IComponentFactory<IMultiStreamSource, IDataLoader> loaderFactory, ColumnInfo[] columns, IDataView trainingData)
IDataView keyData, ColumnInfo[] columns, IDataView trainingData, bool autoConvert)
{
Contracts.AssertValue(env);
env.AssertValue(ch);
ch.AssertValue(infos);
ch.AssertValueOrNull(keyData);
ch.AssertValue(columns);
ch.AssertValue(trainingData);

@@ -544,13 +577,13 @@ private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] info
bldr.ParseAddTermArg(termsArray, ch);
termMap[iinfo] = bldr.Finish();
}
else if (!string.IsNullOrWhiteSpace(file))
else if (keyData != null)
{
// First column using this file.
if (termsFromFile == null)
{
var bldr = Builder.Create(infos[iinfo].TypeSrc, columns[iinfo].Sort);
termsFromFile = CreateFileTermMap(env, ch, file, termsColumn, loaderFactory, bldr);
termsFromFile = CreateTermMapFromData(env, ch, keyData, autoConvert, bldr);
}
if (!termsFromFile.ItemType.Equals(infos[iinfo].TypeSrc.GetItemType()))
{
@@ -559,7 +592,7 @@ private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] info
// a complicated feature would be, and also because it's difficult to see how we
// can logically reconcile "reinterpretation" for different types with the resulting
// data view having an actual type.
throw ch.ExceptUserArg(nameof(file), "Data file terms loaded as type '{0}' but mismatches column '{1}' item type '{2}'",
throw ch.ExceptParam(nameof(keyData), "Terms from input data type '{0}' but mismatches column '{1}' item type '{2}'",
termsFromFile.ItemType, infos[iinfo].Name, infos[iinfo].TypeSrc.GetItemType());
}
termMap[iinfo] = termsFromFile;
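
Callers that previously passed DataFile/TermsColumn/Loader can get the same effect with the new API by materializing the terms file as a single-column data view first. A minimal sketch, mirroring what GetKeyDataViewOrNull does internally; the file name "terms.txt" and the surrounding env and columns variables are placeholders.

// Read one term per line into a single text column, then pass it to the estimator as keyData.
var fileSource = new MultiFileSource("terms.txt");
IDataView keyData = new TextLoader(env,
    columns: new[] { new TextLoader.Column("Term", DataKind.TX, 0) },
    dataSample: fileSource)
    .Read(fileSource);
var estimator = new ValueToKeyMappingEstimator(env, columns, keyData);
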
16 changes: 11 additions & 5 deletions src/Microsoft.ML.Transforms/OneHotEncoding.cs
@@ -139,7 +139,15 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat
col.SetTerms(column.Terms ?? args.Terms);
columns.Add(col);
}
return new OneHotEncodingEstimator(env, columns.ToArray(), args.DataFile, args.TermsColumn, args.Loader).Fit(input).Transform(input) as IDataTransform;
IDataView keyData = null;
if (!string.IsNullOrEmpty(args.DataFile))
{
using (var ch = h.Start("Load term data"))
keyData = ValueToKeyMappingTransformer.GetKeyDataViewOrNull(env, ch, args.DataFile, args.TermsColumn, args.Loader, out bool autoLoaded);
h.AssertValue(keyData);
}
var transformed = new OneHotEncodingEstimator(env, columns.ToArray(), keyData).Fit(input).Transform(input);
return (IDataTransform)transformed;
}

private readonly TransformerChain<ITransformer> _transformer;
@@ -220,13 +228,11 @@ public OneHotEncodingEstimator(IHostEnvironment env, string inputColumn,
{
}

public OneHotEncodingEstimator(IHostEnvironment env, ColumnInfo[] columns,
string file = null, string termsColumn = null,
IComponentFactory<IMultiStreamSource, IDataLoader> loaderFactory = null)
public OneHotEncodingEstimator(IHostEnvironment env, ColumnInfo[] columns, IDataView keyData = null)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(OneHotEncodingEstimator));
_term = new ValueToKeyMappingEstimator(_host, columns, file, termsColumn, loaderFactory);
_term = new ValueToKeyMappingEstimator(_host, columns, keyData);
var binaryCols = new List<(string input, string output)>();
var cols = new List<(string input, string output, bool bag)>();
for (int i = 0; i < columns.Length; i++)
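
The revised OneHotEncodingEstimator constructor is used the same way. This mirrors the CategoricalOneHotEncodingFromSideData test added below; sideData is assumed to be a single-column data view holding the terms, and dataView the data being fitted.

var ci = new OneHotEncodingEstimator.ColumnInfo("A", "CatA", OneHotEncodingTransformer.OutputKind.Bag);
var pipe = new OneHotEncodingEstimator(mlContext, new[] { ci }, sideData);
var output = pipe.Fit(dataView).Transform(dataView);
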
50 changes: 25 additions & 25 deletions src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs
@@ -33,20 +33,20 @@
[assembly: LoadableClass(typeof(IRowMapper), typeof(StopWordsRemovingTransformer), null, typeof(SignatureLoadRowMapper),
"Stopwords Remover Transform", StopWordsRemovingTransformer.LoaderSignature)]

[assembly: LoadableClass(CustomStopWordsRemovingTransform.Summary, typeof(IDataTransform), typeof(CustomStopWordsRemovingTransform), typeof(CustomStopWordsRemovingTransform.Arguments), typeof(SignatureDataTransform),
[assembly: LoadableClass(CustomStopWordsRemovingTransformer.Summary, typeof(IDataTransform), typeof(CustomStopWordsRemovingTransformer), typeof(CustomStopWordsRemovingTransformer.Arguments), typeof(SignatureDataTransform),
"Custom Stopwords Remover Transform", "CustomStopWordsRemoverTransform", "CustomStopWords")]

[assembly: LoadableClass(CustomStopWordsRemovingTransform.Summary, typeof(IDataTransform), typeof(CustomStopWordsRemovingTransform), null, typeof(SignatureLoadDataTransform),
"Custom Stopwords Remover Transform", CustomStopWordsRemovingTransform.LoaderSignature)]
[assembly: LoadableClass(CustomStopWordsRemovingTransformer.Summary, typeof(IDataTransform), typeof(CustomStopWordsRemovingTransformer), null, typeof(SignatureLoadDataTransform),
"Custom Stopwords Remover Transform", CustomStopWordsRemovingTransformer.LoaderSignature)]

[assembly: LoadableClass(CustomStopWordsRemovingTransform.Summary, typeof(CustomStopWordsRemovingTransform), null, typeof(SignatureLoadModel),
"Custom Stopwords Remover Transform", CustomStopWordsRemovingTransform.LoaderSignature)]
[assembly: LoadableClass(CustomStopWordsRemovingTransformer.Summary, typeof(CustomStopWordsRemovingTransformer), null, typeof(SignatureLoadModel),
"Custom Stopwords Remover Transform", CustomStopWordsRemovingTransformer.LoaderSignature)]

[assembly: LoadableClass(typeof(IRowMapper), typeof(CustomStopWordsRemovingTransform), null, typeof(SignatureLoadRowMapper),
"Custom Stopwords Remover Transform", CustomStopWordsRemovingTransform.LoaderSignature)]
[assembly: LoadableClass(typeof(IRowMapper), typeof(CustomStopWordsRemovingTransformer), null, typeof(SignatureLoadRowMapper),
"Custom Stopwords Remover Transform", CustomStopWordsRemovingTransformer.LoaderSignature)]

[assembly: EntryPointModule(typeof(PredefinedStopWordsRemoverFactory))]
[assembly: EntryPointModule(typeof(CustomStopWordsRemovingTransform.LoaderArguments))]
[assembly: EntryPointModule(typeof(CustomStopWordsRemovingTransformer.LoaderArguments))]

namespace Microsoft.ML.Transforms.Text
{
@@ -596,7 +596,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
/// This is usually applied after tokenizing text, so it compares individual tokens
/// (case-insensitive comparison) to the stopwords.
/// </summary>
public sealed class CustomStopWordsRemovingTransform : OneToOneTransformerBase
public sealed class CustomStopWordsRemovingTransformer : OneToOneTransformerBase
{
public sealed class Column : OneToOneColumn
{
@@ -646,9 +646,9 @@ public sealed class LoaderArguments : ArgumentsBase, IStopWordsRemoverFactory
public IDataTransform CreateComponent(IHostEnvironment env, IDataView input, OneToOneColumn[] column)
{
if (Utils.Size(Stopword) > 0)
return new CustomStopWordsRemovingTransform(env, Stopword, column.Select(x => (x.Source, x.Name)).ToArray()).Transform(input) as IDataTransform;
return new CustomStopWordsRemovingTransformer(env, Stopword, column.Select(x => (x.Source, x.Name)).ToArray()).Transform(input) as IDataTransform;
else
return new CustomStopWordsRemovingTransform(env, Stopwords, DataFile, StopwordsColumn, Loader, column.Select(x => (x.Source, x.Name)).ToArray()).Transform(input) as IDataTransform;
return new CustomStopWordsRemovingTransformer(env, Stopwords, DataFile, StopwordsColumn, Loader, column.Select(x => (x.Source, x.Name)).ToArray()).Transform(input) as IDataTransform;
}
}

@@ -665,7 +665,7 @@ private static VersionInfo GetVersionInfo()
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(CustomStopWordsRemovingTransform).Assembly.FullName);
loaderAssemblyName: typeof(CustomStopWordsRemovingTransformer).Assembly.FullName);
}

private const string StopwordsManagerLoaderSignature = "CustomStopWordsManager";
@@ -678,7 +678,7 @@ private static VersionInfo GetStopwordsManagerVersionInfo()
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: StopwordsManagerLoaderSignature,
loaderAssemblyName: typeof(CustomStopWordsRemovingTransform).Assembly.FullName);
loaderAssemblyName: typeof(CustomStopWordsRemovingTransformer).Assembly.FullName);
}

private static readonly ColumnType _outputType = new VectorType(TextType.Instance);
@@ -808,7 +808,7 @@ private void LoadStopWords(IChannel ch, ReadOnlyMemory<char> stopwords, string d
/// <param name="env">The environment.</param>
/// <param name="stopwords">Array of words to remove.</param>
/// <param name="columns">Pairs of columns to remove stop words from.</param>
public CustomStopWordsRemovingTransform(IHostEnvironment env, string[] stopwords, params (string input, string output)[] columns) :
public CustomStopWordsRemovingTransformer(IHostEnvironment env, string[] stopwords, params (string input, string output)[] columns) :
base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns)
{
_stopWordsMap = new NormStr.Pool();
@@ -826,7 +826,7 @@ public CustomStopWordsRemovingTransform(IHostEnvironment env, string[] stopwords
}
}

internal CustomStopWordsRemovingTransform(IHostEnvironment env, string stopwords,
internal CustomStopWordsRemovingTransformer(IHostEnvironment env, string stopwords,
string dataFile, string stopwordsColumn, IComponentFactory<IMultiStreamSource, IDataLoader> loader, params (string input, string output)[] columns) :
base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns)
{
@@ -874,7 +874,7 @@ public override void Save(ModelSaveContext ctx)
});
}

private CustomStopWordsRemovingTransform(IHost host, ModelLoadContext ctx) :
private CustomStopWordsRemovingTransformer(IHost host, ModelLoadContext ctx) :
base(host, ctx)
{
var columnsLength = ColumnPairs.Length;
@@ -919,13 +919,13 @@ private CustomStopWordsRemovingTransform(IHost host, ModelLoadContext ctx) :
}

// Factory method for SignatureLoadModel.
private static CustomStopWordsRemovingTransform Create(IHostEnvironment env, ModelLoadContext ctx)
private static CustomStopWordsRemovingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register(RegistrationName);
host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
return new CustomStopWordsRemovingTransform(host, ctx);
return new CustomStopWordsRemovingTransformer(host, ctx);
}

// Factory method for SignatureDataTransform.
@@ -942,11 +942,11 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat
var item = args.Column[i];
cols[i] = (item.Source ?? item.Name, item.Name);
}
CustomStopWordsRemovingTransform transfrom = null;
CustomStopWordsRemovingTransformer transfrom = null;
if (Utils.Size(args.Stopword) > 0)
transfrom = new CustomStopWordsRemovingTransform(env, args.Stopword, cols);
transfrom = new CustomStopWordsRemovingTransformer(env, args.Stopword, cols);
else
transfrom = new CustomStopWordsRemovingTransform(env, args.Stopwords, args.DataFile, args.StopwordsColumn, args.Loader, cols);
transfrom = new CustomStopWordsRemovingTransformer(env, args.Stopwords, args.DataFile, args.StopwordsColumn, args.Loader, cols);
return transfrom.MakeDataTransform(input);
}

@@ -963,9 +963,9 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Sch
private sealed class Mapper : OneToOneMapperBase
{
private readonly ColumnType[] _types;
private readonly CustomStopWordsRemovingTransform _parent;
private readonly CustomStopWordsRemovingTransformer _parent;

public Mapper(CustomStopWordsRemovingTransform parent, Schema inputSchema)
public Mapper(CustomStopWordsRemovingTransformer parent, Schema inputSchema)
: base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), parent, inputSchema)
{
_parent = parent;
@@ -1036,7 +1036,7 @@ protected override Delegate MakeGetter(Row input, int iinfo, Func<int, bool> act
/// This is usually applied after tokenizing text, so it compares individual tokens
/// (case-insensitive comparison) to the stopwords.
/// </summary>
public sealed class CustomStopWordsRemovingEstimator : TrivialEstimator<CustomStopWordsRemovingTransform>
public sealed class CustomStopWordsRemovingEstimator : TrivialEstimator<CustomStopWordsRemovingTransformer>
{
internal const string ExpectedColumnType = "vector of Text type";

@@ -1061,7 +1061,7 @@ public CustomStopWordsRemovingEstimator(IHostEnvironment env, string inputColumn
/// <param name="columns">Pairs of columns to remove stop words on.</param>
/// <param name="stopwords">Array of words to remove.</param>
public CustomStopWordsRemovingEstimator(IHostEnvironment env, (string input, string output)[] columns, string[] stopwords) :
base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CustomStopWordsRemovingEstimator)), new CustomStopWordsRemovingTransform(env, stopwords, columns))
base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CustomStopWordsRemovingEstimator)), new CustomStopWordsRemovingTransformer(env, stopwords, columns))
{
}

42 changes: 40 additions & 2 deletions test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs
@@ -24,14 +24,14 @@ public CategoricalTests(ITestOutputHelper output) : base(output)
{
}

private class TestClass
private sealed class TestClass
{
public int A;
public int B;
public int C;
}

private class TestMeta
private sealed class TestMeta
{
[VectorType(2)]
public string[] A;
@@ -47,6 +47,11 @@ private class TestMeta
public string H;
}

private sealed class TestStringClass
{
public string A;
}

[Fact]
public void CategoricalWorkout()
{
@@ -98,6 +103,39 @@ public void CategoricalOneHotEncoding()
Done();
}

/// <summary>
/// In which we take a categorical value and map it to a vector, but we get the mapping from a side data view
/// rather than the data we are fitting.
/// </summary>
[Fact]
public void CategoricalOneHotEncodingFromSideData()
{
// In this case, whatever the value of the input, the term mapping should come from the optional side data if specified.
var data = new[] { new TestStringClass() { A = "Stay" }, new TestStringClass() { A = "awhile and listen" } };

var mlContext = new MLContext();
var dataView = mlContext.Data.ReadFromEnumerable(data);

var sideDataBuilder = new ArrayDataViewBuilder(mlContext);
sideDataBuilder.AddColumn("Hello", "hello", "my", "friend");
var sideData = sideDataBuilder.GetDataView();

var ci = new OneHotEncodingEstimator.ColumnInfo("A", "CatA", OneHotEncodingTransformer.OutputKind.Bag);
var pipe = new OneHotEncodingEstimator(mlContext, new[] { ci }, sideData);

var output = pipe.Fit(dataView).Transform(dataView);

VBuffer<ReadOnlyMemory<char>> slotNames = default;
output.Schema["CatA"].GetSlotNames(ref slotNames);

Assert.Equal(3, slotNames.Length);
Assert.Equal("hello", slotNames.GetItemOrDefault(0).ToString());
Assert.Equal("my", slotNames.GetItemOrDefault(1).ToString());
Assert.Equal("friend", slotNames.GetItemOrDefault(2).ToString());

Done();
}

[Fact]
public void CategoricalStatic()
{
45 changes: 41 additions & 4 deletions test/Microsoft.ML.Tests/Transformers/ConvertTests.cs
@@ -25,7 +25,7 @@ public ConvertTests(ITestOutputHelper output) : base(output)
{
}

private class TestPrimitiveClass
private sealed class TestPrimitiveClass
{
[VectorType(2)]
public string[] AA;
@@ -53,20 +53,23 @@ private class TestPrimitiveClass
public double[] AN;
}

private class TestClass
private sealed class TestClass
{
public int A;
[VectorType(2)]
public int[] B;
}

public class MetaClass
private sealed class MetaClass
{
public float A;
public string B;

}

private sealed class TestStringClass
{
public string A;
}

[Fact]
public void TestConvertWorkout()
@@ -142,6 +145,40 @@ public void TestConvertWorkout()
Done();
}

/// <summary>
/// Apply <see cref="ValueToKeyMappingEstimator"/> with side data.
/// </summary>
[Fact]
public void ValueToKeyFromSideData()
{
// In this case, whatever the value of the input, the term mapping should come from the optional side data if specified.
var data = new[] { new TestStringClass() { A = "Stay" }, new TestStringClass() { A = "awhile and listen" } };

var mlContext = new MLContext();
var dataView = mlContext.Data.ReadFromEnumerable(data);

var sideDataBuilder = new ArrayDataViewBuilder(mlContext);
sideDataBuilder.AddColumn("Hello", "hello", "my", "friend");
var sideData = sideDataBuilder.GetDataView();

// For some reason the column info is on the *transformer*, not the estimator. Already tracked as issue #1760.
var ci = new ValueToKeyMappingTransformer.ColumnInfo("A", "CatA");
var pipe = mlContext.Transforms.Conversion.MapValueToKey(new[] { ci }, sideData);
var output = pipe.Fit(dataView).Transform(dataView);

VBuffer<ReadOnlyMemory<char>> slotNames = default;
output.Schema["CatA"].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref slotNames);

Assert.Equal(3, slotNames.Length);
Assert.Equal("hello", slotNames.GetItemOrDefault(0).ToString());
Assert.Equal("my", slotNames.GetItemOrDefault(1).ToString());
Assert.Equal("friend", slotNames.GetItemOrDefault(2).ToString());

Done();
}



[Fact]
public void TestCommandLine()
{