Skip to content

Normalization API helpers #446

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/Microsoft.ML.Core/Data/MetadataUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,22 @@ public static bool HasKeyNames(this ISchema schema, int col, int keyCount)
&& type.ItemType.IsText;
}

/// <summary>
/// Returns whether a column has the <see cref="Kinds.IsNormalized"/> metadata set to true.
Copy link
Member

@eerhardt eerhardt Jul 2, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(minor) Do we instead want to document what IsNormalized means? "Returns whether a column <insert what IsNormalized means>". #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I can add a bit more information. Though that information, I might prefer to make that part of the Kinds static class, since the documentation of those are the primary source of truth. This is just something we added as a convenience on top of that.


In reply to: 199596084 [](ancestors = 199596084)

/// That metadata should be set when the data has undergone transforms that would render it
/// "normalized."
/// </summary>
/// <param name="schema">The schema to query</param>
/// <param name="col">Which column in the schema to query</param>
/// <returns>True if and only if the column has the <see cref="Kinds.IsNormalized"/> metadata
/// set to the scalar value <see cref="DvBool.True"/></returns>
public static bool IsNormalized(this ISchema schema, int col)
{
    Contracts.CheckValue(schema, nameof(schema));

    // Ask for the IsNormalized metadata as a scalar boolean. If the lookup
    // does not succeed, the column is reported as not normalized.
    DvBool flag = default(DvBool);
    if (!schema.TryGetMetadata(BoolType.Instance, Kinds.IsNormalized, col, ref flag))
        return false;

    // Only an explicit "true" counts as normalized.
    return flag.IsTrue;
}

/// <summary>
/// Tries to get the metadata kind of the specified type for a column.
/// </summary>
Expand All @@ -347,6 +363,9 @@ public static bool HasKeyNames(this ISchema schema, int col, int keyCount)
/// <returns>True if the metadata of the right type exists, false otherwise</returns>
public static bool TryGetMetadata<T>(this ISchema schema, PrimitiveType type, string kind, int col, ref T value)
{
Contracts.CheckValue(schema, nameof(schema));
Contracts.CheckValue(type, nameof(type));

var metadataType = schema.GetMetadataTypeOrNull(kind, col);
if (!type.Equals(metadataType))
return false;
Expand All @@ -363,7 +382,7 @@ public static bool IsHidden(this ISchema schema, int col)
string name = schema.GetColumnName(col);
int top;
bool tmp = schema.TryGetColumnIndex(name, out top);
Contracts.Assert(tmp, "Why did TryGetColumnIndex return false?");
Contracts.Assert(tmp); // This would only be false if the implementation of schema were buggy.
return !tmp || top != col;
}

Expand Down
70 changes: 39 additions & 31 deletions src/Microsoft.ML.Core/Data/RoleMappedSchema.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
namespace Microsoft.ML.Runtime.Data
{
/// <summary>
/// This contains information about a column in an IDataView. It is essentially a convenience
/// cache containing the name, column index, and column type for the column.
/// This contains information about a column in an <see cref="IDataView"/>. It is essentially a convenience cache
/// containing the name, column index, and column type for the column. The intended usage is that users of <see cref="RoleMappedSchema"/>
/// will have a convenient method of getting the index and type without having to separately query it through the <see cref="ISchema"/>,
/// since practically the first thing a consumer of a <see cref="RoleMappedSchema"/> will want to do once they get a mapping is get
/// the type and index of the corresponding column.
/// </summary>
public sealed class ColumnInfo
{
Expand Down Expand Up @@ -71,12 +74,20 @@ public static ColumnInfo CreateFromIndex(ISchema schema, int index)
}

/// <summary>
/// Encapsulates an ISchema plus column role mapping information. It has convenience fields for
/// several common column roles, but can hold an arbitrary set of column infos. The convenience
/// fields are non-null iff there is a unique column with the corresponding role. When there are
/// no such columns or more than one such column, the field is null. The Has, HasUnique, and
/// HasMultiple methods provide some cardinality information.
/// Note that all columns assigned roles are guaranteed to be non-hidden in this schema.
/// Encapsulates an <see cref="ISchema"/> plus column role mapping information. The purpose of role mappings is to
/// provide information on what each column is intended to be used for. That is: while a given data view may have a column named
/// "Features", by itself that is insufficient: the trainer must be fed a role mapping that says that the role
/// mapping for features is filled by that "Features" column. This allows things like columns not named "Features"
/// to actually fill that role (as opposed to insisting on a hard coding, or having every trainer have to be
/// individually configured). Also, by being a one-to-many mapping, it is a way for learners that can consume
/// multiple features columns to consume that information.
///
/// This class has convenience fields for several common column roles (e.g., <see cref="Feature"/>, <see
/// cref="Label"/>), but can hold an arbitrary set of column infos. The convenience fields are non-null iff there is
/// a unique column with the corresponding role. When there are no such columns or more than one such column, the
/// field is null. The <see cref="Has"/>, <see cref="HasUnique"/>, and <see cref="HasMultiple"/> methods provide
/// some cardinality information. Note that all columns assigned roles are guaranteed to be non-hidden in this
/// schema.
/// </summary>
public sealed class RoleMappedSchema
{
Expand All @@ -85,18 +96,16 @@ public sealed class RoleMappedSchema
private const string GroupString = "Group";
private const string WeightString = "Weight";
private const string NameString = "Name";
private const string IdString = "Id";
private const string FeatureContributionsString = "FeatureContributions";

public struct ColumnRole
{
public static ColumnRole Feature { get { return new ColumnRole(FeatureString); } }
public static ColumnRole Label { get { return new ColumnRole(LabelString); } }
public static ColumnRole Group { get { return new ColumnRole(GroupString); } }
public static ColumnRole Weight { get { return new ColumnRole(WeightString); } }
public static ColumnRole Name { get { return new ColumnRole(NameString); } }
public static ColumnRole Id { get { return new ColumnRole(IdString); } }
public static ColumnRole FeatureContributions { get { return new ColumnRole(FeatureContributionsString); } }
public static ColumnRole Feature => FeatureString;
public static ColumnRole Label => LabelString;
public static ColumnRole Group => GroupString;
public static ColumnRole Weight => WeightString;
public static ColumnRole Name => NameString;
public static ColumnRole FeatureContributions => FeatureContributionsString;

public readonly string Value;

Expand Down Expand Up @@ -152,11 +161,6 @@ public static KeyValuePair<ColumnRole, string> CreatePair(ColumnRole role, strin
/// </summary>
public readonly ColumnInfo Name;

/// <summary>
/// The Id column, when there is exactly one (null otherwise).
/// </summary>
public readonly ColumnInfo Id;

// Maps from role to the associated column infos.
private readonly Dictionary<string, IReadOnlyList<ColumnInfo>> _map;

Expand Down Expand Up @@ -194,9 +198,6 @@ private RoleMappedSchema(ISchema schema, Dictionary<string, IReadOnlyList<Column
case NameString:
Name = cols[0];
break;
case IdString:
Id = cols[0];
break;
}
}
}
Expand Down Expand Up @@ -224,8 +225,8 @@ private static void Add(Dictionary<string, List<ColumnInfo>> map, ColumnRole rol

private static Dictionary<string, List<ColumnInfo>> MapFromNames(ISchema schema, IEnumerable<KeyValuePair<ColumnRole, string>> roles)
{
Contracts.AssertValue(schema, "schema");
Contracts.AssertValue(roles, "roles");
Contracts.AssertValue(schema);
Contracts.AssertValue(roles);

var map = new Dictionary<string, List<ColumnInfo>>();
foreach (var kvp in roles)
Expand All @@ -241,8 +242,8 @@ private static Dictionary<string, List<ColumnInfo>> MapFromNames(ISchema schema,

private static Dictionary<string, List<ColumnInfo>> MapFromNamesOpt(ISchema schema, IEnumerable<KeyValuePair<ColumnRole, string>> roles)
{
Contracts.AssertValue(schema, "schema");
Contracts.AssertValue(roles, "roles");
Contracts.AssertValue(schema);
Contracts.AssertValue(roles);

var map = new Dictionary<string, List<ColumnInfo>>();
foreach (var kvp in roles)
Expand Down Expand Up @@ -334,6 +335,13 @@ public IEnumerable<KeyValuePair<ColumnRole, string>> GetColumnRoleNames(ColumnRo
}
}

/// <summary>
/// Returns the <see cref="ColumnInfo"/> corresponding to <paramref name="role"/> if there is
/// exactly one such mapping, and otherwise throws an exception.
/// </summary>
/// <param name="role">The role to look up</param>
/// <returns>The info corresponding to that role, assuming there was only one column
/// mapped to that</returns>
public ColumnInfo GetUniqueColumn(ColumnRole role)
{
var infos = GetColumns(role);
Expand Down Expand Up @@ -398,9 +406,9 @@ public static RoleMappedSchema CreateOpt(ISchema schema, IEnumerable<KeyValuePai
}

/// <summary>
/// Encapsulates an IDataView plus a corresponding RoleMappedSchema. Note that the schema of the
/// RoleMappedSchema is guaranteed to be the same schema of the IDataView, that is,
/// Data.Schema == Schema.Schema.
/// Encapsulates an <see cref="IDataView"/> plus a corresponding <see cref="RoleMappedSchema"/>.
/// Note that the <see cref="RoleMappedSchema.Schema"/> of <see cref="Schema"/> is
/// guaranteed to equal the <see cref="ISchematized.Schema"/> of <see cref="Data"/>.
/// </summary>
public sealed class RoleMappedData
{
Expand Down
31 changes: 8 additions & 23 deletions src/Microsoft.ML.Data/Commands/TrainCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -468,16 +468,11 @@ private static List<IDataTransform> BacktrackPipe(IDataView dataPipe, out IDataV
Contracts.AssertValue(dataPipe);

var transforms = new List<IDataTransform>();
while (true)
while (dataPipe is IDataTransform xf)
{
// REVIEW: a malicious user could construct a loop in the Source chain, that would
// cause this method to iterate forever (and throw something when the list overflows). There's
// no way to insulate from ALL malicious behavior.

var xf = dataPipe as IDataTransform;
if (xf == null)
break;

transforms.Add(xf);
dataPipe = xf.Source;
Contracts.AssertValue(dataPipe);
Expand Down Expand Up @@ -514,11 +509,8 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra
{
if (autoNorm != NormalizeOption.Yes)
{
var nn = trainer as ITrainerEx;
DvBool isNormalized = DvBool.False;
if (nn == null || !nn.NeedNormalization ||
(schema.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, featCol, ref isNormalized) &&
isNormalized.IsTrue))
if (trainer.NeedNormalization() != true || schema.IsNormalized(featCol))
{
ch.Info("Not adding a normalizer.");
return false;
Expand All @@ -530,20 +522,13 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra
}
}
ch.Info("Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.");
// Quote the feature column name
string quotedFeatureColumnName = featureColumn;
StringBuilder sb = new StringBuilder();
if (CmdQuoter.QuoteValue(quotedFeatureColumnName, sb))
quotedFeatureColumnName = sb.ToString();
var component = new SubComponent<IDataTransform, SignatureDataTransform>("MinMax", string.Format("col={{ name={0} source={0} }}", quotedFeatureColumnName));
var loader = view as IDataLoader;
if (loader != null)
{
view = CompositeDataLoader.Create(env, loader,
new KeyValuePair<string, SubComponent<IDataTransform, SignatureDataTransform>>(null, component));
}
IDataView ApplyNormalizer(IHostEnvironment innerEnv, IDataView input)
Copy link
Contributor

@zeahmed zeahmed Jun 28, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If my PR goes in first, you will be able to use convenience constructor here...:) #Closed

Copy link
Contributor Author

@TomFinley TomFinley Jun 28, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know, I even had a comment about that. :) But Ivan made me remove it. #Closed

=> NormalizeTransform.CreateMinMaxNormalizer(innerEnv, input, featureColumn);

if (view is IDataLoader loader)
view = CompositeDataLoader.ApplyTransform(env, loader, tag: null, creationArgs: null, ApplyNormalizer);
else
view = component.CreateInstance(env, view);
view = ApplyNormalizer(env, view);
return true;
}
return false;
Expand Down
37 changes: 13 additions & 24 deletions src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public sealed class Arguments
public KeyValuePair<string, SubComponent<IDataTransform, SignatureDataTransform>>[] Transform;
}

internal struct TransformEx
private struct TransformEx
{
public readonly string Tag;
public readonly string ArgsString;
Expand Down Expand Up @@ -78,16 +78,14 @@ private static VersionInfo GetVersionInfo()
// The composition of loader plus transforms in order.
private readonly IDataLoader _loader;
private readonly TransformEx[] _transforms;
private readonly IDataView _view;
private readonly ITransposeDataView _tview;
private readonly ITransposeSchema _tschema;
private readonly IHost _host;

/// <summary>
/// Returns the underlying data view of the composite loader.
/// This can be used to programmatically explore the chain of transforms that's inside the composite loader.
/// </summary>
internal IDataView View { get { return _view; } }
internal IDataView View { get; }

/// <summary>
/// Creates a loader according to the specified <paramref name="args"/>.
Expand Down Expand Up @@ -200,7 +198,7 @@ private static IDataLoader ApplyTransformsCore(IHost host, IDataLoader srcLoader
IDataLoader pipeStart;
if (composite != null)
{
srcView = composite._view;
srcView = composite.View;
exes.AddRange(composite._transforms);
pipeStart = composite._loader;
}
Expand Down Expand Up @@ -409,9 +407,9 @@ private CompositeDataLoader(IHost host, TransformEx[] transforms)
_host = host;
_host.AssertNonEmpty(transforms);

_view = transforms[transforms.Length - 1].Transform;
_tview = _view as ITransposeDataView;
_tschema = _tview == null ? new TransposerUtils.SimpleTransposeSchema(_view.Schema) : _tview.TransposeSchema;
View = transforms[transforms.Length - 1].Transform;
_tview = View as ITransposeDataView;
TransposeSchema = _tview?.TransposeSchema ?? new TransposerUtils.SimpleTransposeSchema(View.Schema);

var srcLoader = transforms[0].Transform.Source as IDataLoader;

Expand Down Expand Up @@ -561,43 +559,34 @@ private static string GenerateTag(int index)

public long? GetRowCount(bool lazy = true)
{
return _view.GetRowCount(lazy);
return View.GetRowCount(lazy);
}

public bool CanShuffle
{
get { return _view.CanShuffle; }
}
public bool CanShuffle => View.CanShuffle;

public ISchema Schema
{
get { return _view.Schema; }
}
public ISchema Schema => View.Schema;

public ITransposeSchema TransposeSchema
{
get { return _tschema; }
}
public ITransposeSchema TransposeSchema { get; }

public IRowCursor GetRowCursor(Func<int, bool> predicate, IRandom rand = null)
{
_host.CheckValue(predicate, nameof(predicate));
_host.CheckValueOrNull(rand);
return _view.GetRowCursor(predicate, rand);
return View.GetRowCursor(predicate, rand);
}

public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator,
Func<int, bool> predicate, int n, IRandom rand = null)
{
_host.CheckValue(predicate, nameof(predicate));
_host.CheckValueOrNull(rand);
return _view.GetRowCursorSet(out consolidator, predicate, n, rand);
return View.GetRowCursorSet(out consolidator, predicate, n, rand);
}

public ISlotCursor GetSlotCursor(int col)
{
_host.CheckParam(0 <= col && col < Schema.ColumnCount, nameof(col));
if (_tschema == null || _tschema.GetSlotType(col) == null)
if (TransposeSchema?.GetSlotType(col) == null)
{
throw _host.ExceptParam(nameof(col), "Bad call to GetSlotCursor on untransposable column '{0}'",
Schema.GetColumnName(col));
Expand Down
9 changes: 2 additions & 7 deletions src/Microsoft.ML.Data/Transforms/ConcatTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,8 @@ private void CacheTypes(out ColumnType[] types, out ColumnType[] typesSlotNames,
{
// All meta-data is passed through in this case, so don't need the slot names type.
echoSrc[i] = true;
DvBool b = DvBool.False;
isNormalized[i] =
info.SrcTypes[0].ItemType.IsNumber &&
Input.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, info.SrcIndices[0], ref b) &&
b.IsTrue;
info.SrcTypes[0].ItemType.IsNumber && Input.IsNormalized(info.SrcIndices[0]);
types[i] = info.SrcTypes[0];
continue;
}
Expand All @@ -260,9 +257,7 @@ private void CacheTypes(out ColumnType[] types, out ColumnType[] typesSlotNames,
{
foreach (var srcCol in info.SrcIndices)
{
DvBool b = DvBool.False;
if (!Input.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, srcCol, ref b) ||
!b.IsTrue)
if (!Input.IsNormalized(srcCol))
{
isNormalized[i] = false;
break;
Expand Down
Loading