diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs
index 04f31d844a..ca13d03ab3 100644
--- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs
+++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs
@@ -335,6 +335,22 @@ public static bool HasKeyNames(this ISchema schema, int col, int keyCount)
&& type.ItemType.IsText;
}
+ /// <summary>
+ /// Returns whether a column has the <see cref="Kinds.IsNormalized"/> metadata set to true.
+ /// That metadata should be set when the data has undergone transforms that would render it
+ /// "normalized."
+ /// </summary>
+ /// <param name="schema">The schema to query</param>
+ /// <param name="col">Which column in the schema to query</param>
+ /// <returns>True if and only if the column has <see cref="Kinds.IsNormalized"/> metadata of type <see cref="BoolType"/>
+ /// whose scalar value reports <see cref="DvBool.IsTrue"/>; false when that metadata is absent or not true</returns>
+ public static bool IsNormalized(this ISchema schema, int col)
+ {
+ Contracts.CheckValue(schema, nameof(schema));
+ var value = default(DvBool);
+ return schema.TryGetMetadata(BoolType.Instance, Kinds.IsNormalized, col, ref value) && value.IsTrue;
+ }
+
///
/// Tries to get the metadata kind of the specified type for a column.
///
@@ -347,6 +363,9 @@ public static bool HasKeyNames(this ISchema schema, int col, int keyCount)
/// True if the metadata of the right type exists, false otherwise
public static bool TryGetMetadata(this ISchema schema, PrimitiveType type, string kind, int col, ref T value)
{
+ Contracts.CheckValue(schema, nameof(schema));
+ Contracts.CheckValue(type, nameof(type));
+
var metadataType = schema.GetMetadataTypeOrNull(kind, col);
if (!type.Equals(metadataType))
return false;
@@ -363,7 +382,7 @@ public static bool IsHidden(this ISchema schema, int col)
string name = schema.GetColumnName(col);
int top;
bool tmp = schema.TryGetColumnIndex(name, out top);
- Contracts.Assert(tmp, "Why did TryGetColumnIndex return false?");
+ Contracts.Assert(tmp); // This would only be false if the implementation of schema were buggy.
return !tmp || top != col;
}
diff --git a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs
index 7d609454bf..302d369489 100644
--- a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs
+++ b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs
@@ -9,8 +9,11 @@
namespace Microsoft.ML.Runtime.Data
{
///
- /// This contains information about a column in an IDataView. It is essentially a convenience
- /// cache containing the name, column index, and column type for the column.
+ /// This contains information about a column in an <see cref="IDataView"/>. It is essentially a convenience cache
+ /// containing the name, column index, and column type for the column. The intended usage is that users of <see cref="RoleMappedSchema"/>
+ /// will have a convenient method of getting the index and type without having to separately query it through the <see cref="ISchema"/>,
+ /// since practically the first thing a consumer of a <see cref="RoleMappedSchema"/> will want to do once they get a mapping is get
+ /// the type and index of the corresponding column.
///
public sealed class ColumnInfo
{
@@ -71,12 +74,20 @@ public static ColumnInfo CreateFromIndex(ISchema schema, int index)
}
///
- /// Encapsulates an ISchema plus column role mapping information. It has convenience fields for
- /// several common column roles, but can hold an arbitrary set of column infos. The convenience
- /// fields are non-null iff there is a unique column with the corresponding role. When there are
- /// no such columns or more than one such column, the field is null. The Has, HasUnique, and
- /// HasMultiple methods provide some cardinality information.
- /// Note that all columns assigned roles are guaranteed to be non-hidden in this schema.
+ /// Encapsulates an <see cref="ISchema"/> plus column role mapping information. The purpose of role mappings is to
+ /// provide information on what the intended usage of each column is. That is: while a given data view may have a column named
+ /// "Features", by itself that is insufficient: the trainer must be fed a role mapping that says that the role
+ /// mapping for features is filled by that "Features" column. This allows things like columns not named "Features"
+ /// to actually fill that role (as opposed to insisting on a hard coding, or having every trainer have to be
+ /// individually configured). Also, by being a one-to-many mapping, it is a way for learners that can consume
+ /// multiple features columns to consume that information.
+ ///
+ /// This class has convenience fields for several common column roles (e.g., <see cref="Feature"/>, <see cref="Label"/>),
+ /// but can hold an arbitrary set of column infos. The convenience fields are non-null iff there is
+ /// a unique column with the corresponding role. When there are no such columns or more than one such column, the
+ /// field is null. The <see cref="Has"/>, <see cref="HasUnique"/>, and <see cref="HasMultiple"/> methods provide
+ /// some cardinality information. Note that all columns assigned roles are guaranteed to be non-hidden in this
+ /// schema.
///
public sealed class RoleMappedSchema
{
@@ -85,18 +96,16 @@ public sealed class RoleMappedSchema
private const string GroupString = "Group";
private const string WeightString = "Weight";
private const string NameString = "Name";
- private const string IdString = "Id";
private const string FeatureContributionsString = "FeatureContributions";
public struct ColumnRole
{
- public static ColumnRole Feature { get { return new ColumnRole(FeatureString); } }
- public static ColumnRole Label { get { return new ColumnRole(LabelString); } }
- public static ColumnRole Group { get { return new ColumnRole(GroupString); } }
- public static ColumnRole Weight { get { return new ColumnRole(WeightString); } }
- public static ColumnRole Name { get { return new ColumnRole(NameString); } }
- public static ColumnRole Id { get { return new ColumnRole(IdString); } }
- public static ColumnRole FeatureContributions { get { return new ColumnRole(FeatureContributionsString); } }
+ public static ColumnRole Feature => FeatureString;
+ public static ColumnRole Label => LabelString;
+ public static ColumnRole Group => GroupString;
+ public static ColumnRole Weight => WeightString;
+ public static ColumnRole Name => NameString;
+ public static ColumnRole FeatureContributions => FeatureContributionsString;
public readonly string Value;
@@ -152,11 +161,6 @@ public static KeyValuePair CreatePair(ColumnRole role, strin
///
public readonly ColumnInfo Name;
- ///
- /// The Id column, when there is exactly one (null otherwise).
- ///
- public readonly ColumnInfo Id;
-
// Maps from role to the associated column infos.
private readonly Dictionary> _map;
@@ -194,9 +198,6 @@ private RoleMappedSchema(ISchema schema, Dictionary> map, ColumnRole rol
private static Dictionary> MapFromNames(ISchema schema, IEnumerable> roles)
{
- Contracts.AssertValue(schema, "schema");
- Contracts.AssertValue(roles, "roles");
+ Contracts.AssertValue(schema);
+ Contracts.AssertValue(roles);
var map = new Dictionary>();
foreach (var kvp in roles)
@@ -241,8 +242,8 @@ private static Dictionary> MapFromNames(ISchema schema,
private static Dictionary> MapFromNamesOpt(ISchema schema, IEnumerable> roles)
{
- Contracts.AssertValue(schema, "schema");
- Contracts.AssertValue(roles, "roles");
+ Contracts.AssertValue(schema);
+ Contracts.AssertValue(roles);
var map = new Dictionary>();
foreach (var kvp in roles)
@@ -334,6 +335,13 @@ public IEnumerable> GetColumnRoleNames(ColumnRo
}
}
+ /// <summary>
+ /// Returns the <see cref="ColumnInfo"/> corresponding to <paramref name="role"/> if there is
+ /// exactly one such mapping, and otherwise throws an exception.
+ /// </summary>
+ /// <param name="role">The role to look up</param>
+ /// <returns>The info corresponding to that role, assuming there was only one column
+ /// mapped to that role</returns>
public ColumnInfo GetUniqueColumn(ColumnRole role)
{
var infos = GetColumns(role);
@@ -398,9 +406,9 @@ public static RoleMappedSchema CreateOpt(ISchema schema, IEnumerable
- /// Encapsulates an IDataView plus a corresponding RoleMappedSchema. Note that the schema of the
- /// RoleMappedSchema is guaranteed to be the same schema of the IDataView, that is,
- /// Data.Schema == Schema.Schema.
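+ /// Encapsulates an <see cref="IDataView"/> plus a corresponding <see cref="RoleMappedSchema"/>.
+ /// Note that the <see cref="RoleMappedSchema.Schema"/> of <see cref="Schema"/> is
+ /// guaranteed to equal the <see cref="IDataView.Schema"/> of <see cref="Data"/>, that is, <c>Data.Schema == Schema.Schema</c>.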
///
public sealed class RoleMappedData
{
diff --git a/src/Microsoft.ML.Data/Commands/TrainCommand.cs b/src/Microsoft.ML.Data/Commands/TrainCommand.cs
index e55a5a3992..b5b23f9020 100644
--- a/src/Microsoft.ML.Data/Commands/TrainCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/TrainCommand.cs
@@ -468,16 +468,11 @@ private static List BacktrackPipe(IDataView dataPipe, out IDataV
Contracts.AssertValue(dataPipe);
var transforms = new List();
- while (true)
+ while (dataPipe is IDataTransform xf)
{
// REVIEW: a malicious user could construct a loop in the Source chain, that would
// cause this method to iterate forever (and throw something when the list overflows). There's
// no way to insulate from ALL malicious behavior.
-
- var xf = dataPipe as IDataTransform;
- if (xf == null)
- break;
-
transforms.Add(xf);
dataPipe = xf.Source;
Contracts.AssertValue(dataPipe);
@@ -514,11 +509,8 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra
{
if (autoNorm != NormalizeOption.Yes)
{
- var nn = trainer as ITrainerEx;
DvBool isNormalized = DvBool.False;
- if (nn == null || !nn.NeedNormalization ||
- (schema.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, featCol, ref isNormalized) &&
- isNormalized.IsTrue))
+ if (trainer.NeedNormalization() != true || schema.IsNormalized(featCol))
{
ch.Info("Not adding a normalizer.");
return false;
@@ -530,20 +522,13 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra
}
}
ch.Info("Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.");
- // Quote the feature column name
- string quotedFeatureColumnName = featureColumn;
- StringBuilder sb = new StringBuilder();
- if (CmdQuoter.QuoteValue(quotedFeatureColumnName, sb))
- quotedFeatureColumnName = sb.ToString();
- var component = new SubComponent("MinMax", string.Format("col={{ name={0} source={0} }}", quotedFeatureColumnName));
- var loader = view as IDataLoader;
- if (loader != null)
- {
- view = CompositeDataLoader.Create(env, loader,
- new KeyValuePair>(null, component));
- }
+ IDataView ApplyNormalizer(IHostEnvironment innerEnv, IDataView input)
+ => NormalizeTransform.CreateMinMaxNormalizer(innerEnv, input, featureColumn);
+
+ if (view is IDataLoader loader)
+ view = CompositeDataLoader.ApplyTransform(env, loader, tag: null, creationArgs: null, ApplyNormalizer);
else
- view = component.CreateInstance(env, view);
+ view = ApplyNormalizer(env, view);
return true;
}
return false;
diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
index 4a7106c2df..cab6467043 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
@@ -41,7 +41,7 @@ public sealed class Arguments
public KeyValuePair>[] Transform;
}
- internal struct TransformEx
+ private struct TransformEx
{
public readonly string Tag;
public readonly string ArgsString;
@@ -78,16 +78,14 @@ private static VersionInfo GetVersionInfo()
// The composition of loader plus transforms in order.
private readonly IDataLoader _loader;
private readonly TransformEx[] _transforms;
- private readonly IDataView _view;
private readonly ITransposeDataView _tview;
- private readonly ITransposeSchema _tschema;
private readonly IHost _host;
///
/// Returns the underlying data view of the composite loader.
/// This can be used to programmatically explore the chain of transforms that's inside the composite loader.
///
- internal IDataView View { get { return _view; } }
+ internal IDataView View { get; }
///
/// Creates a loader according to the specified .
@@ -200,7 +198,7 @@ private static IDataLoader ApplyTransformsCore(IHost host, IDataLoader srcLoader
IDataLoader pipeStart;
if (composite != null)
{
- srcView = composite._view;
+ srcView = composite.View;
exes.AddRange(composite._transforms);
pipeStart = composite._loader;
}
@@ -409,9 +407,9 @@ private CompositeDataLoader(IHost host, TransformEx[] transforms)
_host = host;
_host.AssertNonEmpty(transforms);
- _view = transforms[transforms.Length - 1].Transform;
- _tview = _view as ITransposeDataView;
- _tschema = _tview == null ? new TransposerUtils.SimpleTransposeSchema(_view.Schema) : _tview.TransposeSchema;
+ View = transforms[transforms.Length - 1].Transform;
+ _tview = View as ITransposeDataView;
+ TransposeSchema = _tview?.TransposeSchema ?? new TransposerUtils.SimpleTransposeSchema(View.Schema);
var srcLoader = transforms[0].Transform.Source as IDataLoader;
@@ -561,29 +559,20 @@ private static string GenerateTag(int index)
public long? GetRowCount(bool lazy = true)
{
- return _view.GetRowCount(lazy);
+ return View.GetRowCount(lazy);
}
- public bool CanShuffle
- {
- get { return _view.CanShuffle; }
- }
+ public bool CanShuffle => View.CanShuffle;
- public ISchema Schema
- {
- get { return _view.Schema; }
- }
+ public ISchema Schema => View.Schema;
- public ITransposeSchema TransposeSchema
- {
- get { return _tschema; }
- }
+ public ITransposeSchema TransposeSchema { get; }
public IRowCursor GetRowCursor(Func predicate, IRandom rand = null)
{
_host.CheckValue(predicate, nameof(predicate));
_host.CheckValueOrNull(rand);
- return _view.GetRowCursor(predicate, rand);
+ return View.GetRowCursor(predicate, rand);
}
public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator,
@@ -591,13 +580,13 @@ public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator,
{
_host.CheckValue(predicate, nameof(predicate));
_host.CheckValueOrNull(rand);
- return _view.GetRowCursorSet(out consolidator, predicate, n, rand);
+ return View.GetRowCursorSet(out consolidator, predicate, n, rand);
}
public ISlotCursor GetSlotCursor(int col)
{
_host.CheckParam(0 <= col && col < Schema.ColumnCount, nameof(col));
- if (_tschema == null || _tschema.GetSlotType(col) == null)
+ if (TransposeSchema?.GetSlotType(col) == null)
{
throw _host.ExceptParam(nameof(col), "Bad call to GetSlotCursor on untransposable column '{0}'",
Schema.GetColumnName(col));
diff --git a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs
index a51f502ecd..a6c0e3f490 100644
--- a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs
@@ -245,11 +245,8 @@ private void CacheTypes(out ColumnType[] types, out ColumnType[] typesSlotNames,
{
// All meta-data is passed through in this case, so don't need the slot names type.
echoSrc[i] = true;
- DvBool b = DvBool.False;
isNormalized[i] =
- info.SrcTypes[0].ItemType.IsNumber &&
- Input.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, info.SrcIndices[0], ref b) &&
- b.IsTrue;
+ info.SrcTypes[0].ItemType.IsNumber && Input.IsNormalized(info.SrcIndices[0]);
types[i] = info.SrcTypes[0];
continue;
}
@@ -260,9 +257,7 @@ private void CacheTypes(out ColumnType[] types, out ColumnType[] typesSlotNames,
{
foreach (var srcCol in info.SrcIndices)
{
- DvBool b = DvBool.False;
- if (!Input.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, srcCol, ref b) ||
- !b.IsTrue)
+ if (!Input.IsNormalized(srcCol))
{
isNormalized[i] = false;
break;
diff --git a/src/Microsoft.ML.Transforms/NormalizeColumn.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
similarity index 100%
rename from src/Microsoft.ML.Transforms/NormalizeColumn.cs
rename to src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
diff --git a/src/Microsoft.ML.Transforms/NormalizeColumnDbl.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs
similarity index 100%
rename from src/Microsoft.ML.Transforms/NormalizeColumnDbl.cs
rename to src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs
diff --git a/src/Microsoft.ML.Transforms/NormalizeColumnSng.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs
similarity index 100%
rename from src/Microsoft.ML.Transforms/NormalizeColumnSng.cs
rename to src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs
diff --git a/src/Microsoft.ML.Transforms/NormalizeTransform.cs b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs
similarity index 82%
rename from src/Microsoft.ML.Transforms/NormalizeTransform.cs
rename to src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs
index e5602918b6..318bbb44f1 100644
--- a/src/Microsoft.ML.Transforms/NormalizeTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs
@@ -197,6 +197,36 @@ private NormalizeTransform(IHost host, ArgumentsBase args, IDataView input,
SetMetadata();
}
+ /// <summary>
+ /// Potentially apply a min-max normalizer to the data's feature column, re-creating the role-mapped
+ /// data over the normalized view with the column role names of the original schema.
+ /// </summary>
+ /// <param name="env">The host environment to use to potentially instantiate the transform</param>
+ /// <param name="data">The role-mapped data that is potentially going to be modified by this method.</param>
+ /// <param name="trainer">The trainer to query with <see cref="NormalizeUtils.NeedNormalization(ITrainer)"/>.
+ /// This method will not modify <paramref name="data"/> if the return from that is <c>null</c> or
+ /// <c>false</c>.</param>
+ /// <returns>True if the normalizer was applied and <paramref name="data"/> was modified</returns>
+ public static bool CreateIfNeeded(IHostEnvironment env, ref RoleMappedData data, ITrainer trainer)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(data, nameof(data));
+ env.CheckValue(trainer, nameof(trainer));
+
+ // If this is false or null, the trainer did not ask for normalization, so we do not normalize.
+ if (trainer.NeedNormalization() != true)
+ return false;
+ // If this is true (already normalized) or null (no unique feature column defined), we do not want to normalize.
+ if (data.Schema.FeaturesAreNormalized() != false)
+ return false;
+ var featInfo = data.Schema.Feature;
+ env.AssertValue(featInfo); // Should be defined, if FeaturesAreNormalized returned a definite value.
+
+ var view = CreateMinMaxNormalizer(env, data.Data, name: featInfo.Name);
+ data = RoleMappedData.Create(view, data.Schema.GetColumnRoleNames());
+ return true;
+ }
+
private NormalizeTransform(IHost host, ModelLoadContext ctx, IDataView input)
: base(host, ctx, input, null)
{
@@ -329,6 +359,44 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou
}
}
+ public static class NormalizeUtils
+ {
+ /// <summary>
+ /// Tells whether the trainer wants normalization.
+ /// </summary>
+ /// <remarks>This method works via testing whether the trainer implements the optional interface
+ /// <see cref="ITrainerEx"/>, via the Boolean <see cref="ITrainerEx.NeedNormalization"/> property.
+ /// If <paramref name="trainer"/> does not implement that interface, then we return <c>null</c></remarks>
+ /// <param name="trainer">The trainer to query</param>
+ /// <returns>Whether the trainer wants normalization</returns>
+ public static bool? NeedNormalization(this ITrainer trainer)
+ {
+ Contracts.CheckValue(trainer, nameof(trainer));
+ return (trainer as ITrainerEx)?.NeedNormalization;
+ }
+
+ /// <summary>
+ /// Returns whether the feature column in the schema is indicated to be normalized. If the features column is not
+ /// specified on the schema, then this will return <c>null</c>.
+ /// </summary>
+ /// <param name="schema">The role-mapped schema to query</param>
+ /// <returns>Returns <c>null</c> if <paramref name="schema"/> does not have <see cref="RoleMappedSchema.Feature"/>
+ /// defined, and otherwise returns a Boolean value as returned from <see cref="MetadataUtils.IsNormalized(ISchema, int)"/>
+ /// on that feature column</returns>
+ /// <seealso cref="MetadataUtils.IsNormalized(ISchema, int)"/>
+ public static bool? FeaturesAreNormalized(this RoleMappedSchema schema)
+ {
+ // REVIEW: The role mapped data has the ability to have multiple columns fill the role of features, which is
+ // useful in some trainers that are nonetheless parametric and can therefore benefit from normalization.
+ Contracts.CheckValue(schema, nameof(schema));
+ var featInfo = schema.Feature;
+ return featInfo == null ? default(bool?) : schema.Schema.IsNormalized(featInfo.Index);
+ }
+ }
+
+ /// <summary>
+ /// This contains entry-point definitions related to <see cref="NormalizeTransform"/>.
+ /// </summary>
public static class Normalize
{
[TlcModule.EntryPoint(Name = "Transforms.MinMaxNormalizer", Desc = NormalizeTransform.MinMaxNormalizerSummary, UserName = NormalizeTransform.MinMaxNormalizerUserName, ShortName = NormalizeTransform.MinMaxNormalizerShortName)]
@@ -402,14 +470,10 @@ public static CommonOutputs.TransformOutput SupervisedBin(IHostEnvironment env,
var columnsToNormalize = new List();
foreach (var column in input.Column)
{
- int col;
- if (!schema.TryGetColumnIndex(column.Source, out col))
+ if (!schema.TryGetColumnIndex(column.Source, out int col))
throw env.ExceptUserArg(nameof(input.Column), $"Column '{column.Source}' does not exist.");
- if (!schema.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, col, ref isNormalized) ||
- isNormalized.IsFalse)
- {
+ if (!schema.IsNormalized(col))
columnsToNormalize.Add(column);
- }
}
var entryPoints = new List();
diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs
index 654735c4b6..1cb4456b68 100644
--- a/src/Microsoft.ML.FastTree/FastTree.cs
+++ b/src/Microsoft.ML.FastTree/FastTree.cs
@@ -1338,28 +1338,25 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB
IDataView data = examples.Data;
// Convert the label column, if one exists.
- var labelInfo = examples.Schema.Label;
- if (labelInfo != null)
+ var labelName = examples.Schema.Label?.Name;
+ if (labelName != null)
{
var convArgs = new LabelConvertTransform.Arguments();
- var convCol = new LabelConvertTransform.Column() { Name = labelInfo.Name, Source = labelInfo.Name };
+ var convCol = new LabelConvertTransform.Column() { Name = labelName, Source = labelName };
convArgs.Column = new LabelConvertTransform.Column[] { convCol };
data = new LabelConvertTransform(Host, convArgs, data);
- labelInfo = ColumnInfo.CreateFromName(data.Schema, convCol.Name, "converted label");
}
// Convert the group column, if one exists.
- var groupInfo = examples.Schema.Group;
- if (groupInfo != null)
+ if (examples.Schema.Group != null)
{
var convArgs = new ConvertTransform.Arguments();
var convCol = new ConvertTransform.Column
{
ResultType = DataKind.U8
};
- convCol.Name = convCol.Source = groupInfo.Name;
+ convCol.Name = convCol.Source = examples.Schema.Group.Name;
convArgs.Column = new ConvertTransform.Column[] { convCol };
data = new ConvertTransform(Host, convArgs, data);
- groupInfo = ColumnInfo.CreateFromName(data.Schema, convCol.Name, "converted group id");
}
// Since we've passed it through a few transforms, reconstitute the mapping on the