From 6c9bc0eac44b0af2e775a423516773ada2406db3 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Wed, 27 Jun 2018 14:54:45 -0700 Subject: [PATCH 1/5] API conveniences for the Normalize transform --- src/Microsoft.ML.Core/Data/MetadataUtils.cs | 19 ++++- .../Data/RoleMappedSchema.cs | 68 ++++++++-------- .../Commands/TrainCommand.cs | 28 +++---- .../DataLoadSave/CompositeDataLoader.cs | 37 ++++----- .../Transforms/ConcatTransform.cs | 9 +-- .../Transforms}/NormalizeColumn.cs | 0 .../Transforms}/NormalizeColumnDbl.cs | 0 .../Transforms}/NormalizeColumnSng.cs | 0 .../Transforms}/NormalizeTransform.cs | 78 +++++++++++++++++-- src/Microsoft.ML.FastTree/FastTree.cs | 7 +- 10 files changed, 156 insertions(+), 90 deletions(-) rename src/{Microsoft.ML.Transforms => Microsoft.ML.Data/Transforms}/NormalizeColumn.cs (100%) rename src/{Microsoft.ML.Transforms => Microsoft.ML.Data/Transforms}/NormalizeColumnDbl.cs (100%) rename src/{Microsoft.ML.Transforms => Microsoft.ML.Data/Transforms}/NormalizeColumnSng.cs (100%) rename src/{Microsoft.ML.Transforms => Microsoft.ML.Data/Transforms}/NormalizeTransform.cs (82%) diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs index 04f31d844a..a231909f66 100644 --- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs +++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs @@ -335,6 +335,20 @@ public static bool HasKeyNames(this ISchema schema, int col, int keyCount) && type.ItemType.IsText; } + /// + /// Returns whether a column has the metadata set to true. + /// + /// The schema to query + /// Which column in the schema to query + /// True if and only if the column has the metadata + /// set to the scalar value + public static bool IsNormalized(this ISchema schema, int col) + { + Contracts.CheckValue(schema, nameof(schema)); + var value = default(DvBool); + return schema.TryGetMetadata(BoolType.Instance, Kinds.IsNormalized, col, ref value) && value.IsTrue; + } + /// /// Tries to get the metadata kind of the specified type for a column. /// @@ -347,6 +361,9 @@ public static bool HasKeyNames(this ISchema schema, int col, int keyCount) /// True if the metadata of the right type exists, false otherwise public static bool TryGetMetadata(this ISchema schema, PrimitiveType type, string kind, int col, ref T value) { + Contracts.CheckValue(schema, nameof(schema)); + Contracts.CheckValue(type, nameof(type)); + var metadataType = schema.GetMetadataTypeOrNull(kind, col); if (!type.Equals(metadataType)) return false; @@ -363,7 +380,7 @@ public static bool IsHidden(this ISchema schema, int col) string name = schema.GetColumnName(col); int top; bool tmp = schema.TryGetColumnIndex(name, out top); - Contracts.Assert(tmp, "Why did TryGetColumnIndex return false?"); + Contracts.Assert(tmp); // This would only be false if the implementation of schema were buggy. return !tmp || top != col; } diff --git a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs index 7d609454bf..23e5059af2 100644 --- a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs +++ b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs @@ -9,8 +9,9 @@ namespace Microsoft.ML.Runtime.Data { /// - /// This contains information about a column in an IDataView. It is essentially a convenience - /// cache containing the name, column index, and column type for the column. + /// This contains information about a column in an . It is essentially a convenience cache + /// containing the name, column index, and column type for the column. The intended usage is that users of + /// to get the column index and type associated with /// public sealed class ColumnInfo { @@ -71,12 +72,20 @@ public static ColumnInfo CreateFromIndex(ISchema schema, int index) } /// - /// Encapsulates an ISchema plus column role mapping information. It has convenience fields for - /// several common column roles, but can hold an arbitrary set of column infos. The convenience - /// fields are non-null iff there is a unique column with the corresponding role. When there are - /// no such columns or more than one such column, the field is null. The Has, HasUnique, and - /// HasMultiple methods provide some cardinality information. - /// Note that all columns assigned roles are guaranteed to be non-hidden in this schema. + /// Encapsulates an plus column role mapping information. The purpose of role mappings is to + /// provide information on what the intended usage is for. That is: while a given data view may have a column named + /// "Features", by itself that is insufficient: the trainer must be fed a role mapping that says that the role + /// mapping for features is filled by that "Features" column. This allows things like columns not named "Features" + /// to actually fill that role (as opposed to insisting on a hard coding, or having every trainer have to be + /// individually configured). Also, by being a one-to-many mapping, it is a way for learners that can consume + /// multiple features columns to consume that information. + /// + /// This class has convenience fields for several common column roles (se.g., , ), but can hold an arbitrary set of column infos. The convenience fields are non-null iff there is + /// a unique column with the corresponding role. When there are no such columns or more than one such column, the + /// field is null. The , , and methods provide + /// some cardinality information. Note that all columns assigned roles are guaranteed to be non-hidden in this + /// schema. /// public sealed class RoleMappedSchema { @@ -85,18 +94,16 @@ public sealed class RoleMappedSchema private const string GroupString = "Group"; private const string WeightString = "Weight"; private const string NameString = "Name"; - private const string IdString = "Id"; private const string FeatureContributionsString = "FeatureContributions"; public struct ColumnRole { - public static ColumnRole Feature { get { return new ColumnRole(FeatureString); } } - public static ColumnRole Label { get { return new ColumnRole(LabelString); } } - public static ColumnRole Group { get { return new ColumnRole(GroupString); } } - public static ColumnRole Weight { get { return new ColumnRole(WeightString); } } - public static ColumnRole Name { get { return new ColumnRole(NameString); } } - public static ColumnRole Id { get { return new ColumnRole(IdString); } } - public static ColumnRole FeatureContributions { get { return new ColumnRole(FeatureContributionsString); } } + public static ColumnRole Feature => FeatureString; + public static ColumnRole Label => LabelString; + public static ColumnRole Group => GroupString; + public static ColumnRole Weight => WeightString; + public static ColumnRole Name => NameString; + public static ColumnRole FeatureContributions => FeatureContributionsString; public readonly string Value; @@ -152,11 +159,6 @@ public static KeyValuePair CreatePair(ColumnRole role, strin /// public readonly ColumnInfo Name; - /// - /// The Id column, when there is exactly one (null otherwise). - /// - public readonly ColumnInfo Id; - // Maps from role to the associated column infos. private readonly Dictionary> _map; @@ -194,9 +196,6 @@ private RoleMappedSchema(ISchema schema, Dictionary> map, ColumnRole rol private static Dictionary> MapFromNames(ISchema schema, IEnumerable> roles) { - Contracts.AssertValue(schema, "schema"); - Contracts.AssertValue(roles, "roles"); + Contracts.AssertValue(schema); + Contracts.AssertValue(roles); var map = new Dictionary>(); foreach (var kvp in roles) @@ -241,8 +240,8 @@ private static Dictionary> MapFromNames(ISchema schema, private static Dictionary> MapFromNamesOpt(ISchema schema, IEnumerable> roles) { - Contracts.AssertValue(schema, "schema"); - Contracts.AssertValue(roles, "roles"); + Contracts.AssertValue(schema); + Contracts.AssertValue(roles); var map = new Dictionary>(); foreach (var kvp in roles) @@ -334,6 +333,13 @@ public IEnumerable> GetColumnRoleNames(ColumnRo } } + /// + /// Returns the corresponding to if there is + /// exactly one such mapping, and otherwise throws an exception. + /// + /// The role to look up + /// The info corresponding to that role, assuming there was only one column + /// mapped to that public ColumnInfo GetUniqueColumn(ColumnRole role) { var infos = GetColumns(role); @@ -398,9 +404,9 @@ public static RoleMappedSchema CreateOpt(ISchema schema, IEnumerable - /// Encapsulates an IDataView plus a corresponding RoleMappedSchema. Note that the schema of the - /// RoleMappedSchema is guaranteed to be the same schema of the IDataView, that is, - /// Data.Schema == Schema.Schema. + /// Encapsulates an plus a corresponding . + /// Note that the schema of of is + /// guaranteed to equal the the of . /// public sealed class RoleMappedData { diff --git a/src/Microsoft.ML.Data/Commands/TrainCommand.cs b/src/Microsoft.ML.Data/Commands/TrainCommand.cs index e55a5a3992..c4500a103a 100644 --- a/src/Microsoft.ML.Data/Commands/TrainCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainCommand.cs @@ -514,11 +514,8 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra { if (autoNorm != NormalizeOption.Yes) { - var nn = trainer as ITrainerEx; DvBool isNormalized = DvBool.False; - if (nn == null || !nn.NeedNormalization || - (schema.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, featCol, ref isNormalized) && - isNormalized.IsTrue)) + if (trainer.NeedNormalization() != true || schema.IsNormalized(featCol)) { ch.Info("Not adding a normalizer."); return false; @@ -530,20 +527,17 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra } } ch.Info("Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off."); - // Quote the feature column name - string quotedFeatureColumnName = featureColumn; - StringBuilder sb = new StringBuilder(); - if (CmdQuoter.QuoteValue(quotedFeatureColumnName, sb)) - quotedFeatureColumnName = sb.ToString(); - var component = new SubComponent("MinMax", string.Format("col={{ name={0} source={0} }}", quotedFeatureColumnName)); - var loader = view as IDataLoader; - if (loader != null) - { - view = CompositeDataLoader.Create(env, loader, - new KeyValuePair>(null, component)); - } + // REVIEW: This verbose constructor should be replaced with zeahmed's enhancements once #405 is committed. + IDataView ApplyNormalizer(IHostEnvironment innerEnv, IDataView input) + => NormalizeTransform.Create(innerEnv, new NormalizeTransform.MinMaxArguments() + { + Column = new[] { new NormalizeTransform.AffineColumn { Source = featureColumn, Name = featureColumn } } + }, input); + + if (view is IDataLoader loader) + view = CompositeDataLoader.ApplyTransform(env, loader, tag: null, creationArgs: null, ApplyNormalizer); else - view = component.CreateInstance(env, view); + view = ApplyNormalizer(env, view); return true; } return false; diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs index 4a7106c2df..cab6467043 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs @@ -41,7 +41,7 @@ public sealed class Arguments public KeyValuePair>[] Transform; } - internal struct TransformEx + private struct TransformEx { public readonly string Tag; public readonly string ArgsString; @@ -78,16 +78,14 @@ private static VersionInfo GetVersionInfo() // The composition of loader plus transforms in order. private readonly IDataLoader _loader; private readonly TransformEx[] _transforms; - private readonly IDataView _view; private readonly ITransposeDataView _tview; - private readonly ITransposeSchema _tschema; private readonly IHost _host; /// /// Returns the underlying data view of the composite loader. /// This can be used to programmatically explore the chain of transforms that's inside the composite loader. /// - internal IDataView View { get { return _view; } } + internal IDataView View { get; } /// /// Creates a loader according to the specified . @@ -200,7 +198,7 @@ private static IDataLoader ApplyTransformsCore(IHost host, IDataLoader srcLoader IDataLoader pipeStart; if (composite != null) { - srcView = composite._view; + srcView = composite.View; exes.AddRange(composite._transforms); pipeStart = composite._loader; } @@ -409,9 +407,9 @@ private CompositeDataLoader(IHost host, TransformEx[] transforms) _host = host; _host.AssertNonEmpty(transforms); - _view = transforms[transforms.Length - 1].Transform; - _tview = _view as ITransposeDataView; - _tschema = _tview == null ? new TransposerUtils.SimpleTransposeSchema(_view.Schema) : _tview.TransposeSchema; + View = transforms[transforms.Length - 1].Transform; + _tview = View as ITransposeDataView; + TransposeSchema = _tview?.TransposeSchema ?? new TransposerUtils.SimpleTransposeSchema(View.Schema); var srcLoader = transforms[0].Transform.Source as IDataLoader; @@ -561,29 +559,20 @@ private static string GenerateTag(int index) public long? GetRowCount(bool lazy = true) { - return _view.GetRowCount(lazy); + return View.GetRowCount(lazy); } - public bool CanShuffle - { - get { return _view.CanShuffle; } - } + public bool CanShuffle => View.CanShuffle; - public ISchema Schema - { - get { return _view.Schema; } - } + public ISchema Schema => View.Schema; - public ITransposeSchema TransposeSchema - { - get { return _tschema; } - } + public ITransposeSchema TransposeSchema { get; } public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); - return _view.GetRowCursor(predicate, rand); + return View.GetRowCursor(predicate, rand); } public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, @@ -591,13 +580,13 @@ public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, { _host.CheckValue(predicate, nameof(predicate)); _host.CheckValueOrNull(rand); - return _view.GetRowCursorSet(out consolidator, predicate, n, rand); + return View.GetRowCursorSet(out consolidator, predicate, n, rand); } public ISlotCursor GetSlotCursor(int col) { _host.CheckParam(0 <= col && col < Schema.ColumnCount, nameof(col)); - if (_tschema == null || _tschema.GetSlotType(col) == null) + if (TransposeSchema?.GetSlotType(col) == null) { throw _host.ExceptParam(nameof(col), "Bad call to GetSlotCursor on untransposable column '{0}'", Schema.GetColumnName(col)); diff --git a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs index a51f502ecd..a6c0e3f490 100644 --- a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs @@ -245,11 +245,8 @@ private void CacheTypes(out ColumnType[] types, out ColumnType[] typesSlotNames, { // All meta-data is passed through in this case, so don't need the slot names type. echoSrc[i] = true; - DvBool b = DvBool.False; isNormalized[i] = - info.SrcTypes[0].ItemType.IsNumber && - Input.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, info.SrcIndices[0], ref b) && - b.IsTrue; + info.SrcTypes[0].ItemType.IsNumber && Input.IsNormalized(info.SrcIndices[0]); types[i] = info.SrcTypes[0]; continue; } @@ -260,9 +257,7 @@ private void CacheTypes(out ColumnType[] types, out ColumnType[] typesSlotNames, { foreach (var srcCol in info.SrcIndices) { - DvBool b = DvBool.False; - if (!Input.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, srcCol, ref b) || - !b.IsTrue) + if (!Input.IsNormalized(srcCol)) { isNormalized[i] = false; break; diff --git a/src/Microsoft.ML.Transforms/NormalizeColumn.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs similarity index 100% rename from src/Microsoft.ML.Transforms/NormalizeColumn.cs rename to src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs diff --git a/src/Microsoft.ML.Transforms/NormalizeColumnDbl.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs similarity index 100% rename from src/Microsoft.ML.Transforms/NormalizeColumnDbl.cs rename to src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs diff --git a/src/Microsoft.ML.Transforms/NormalizeColumnSng.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs similarity index 100% rename from src/Microsoft.ML.Transforms/NormalizeColumnSng.cs rename to src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs diff --git a/src/Microsoft.ML.Transforms/NormalizeTransform.cs b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs similarity index 82% rename from src/Microsoft.ML.Transforms/NormalizeTransform.cs rename to src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs index e5602918b6..f5faff9d0e 100644 --- a/src/Microsoft.ML.Transforms/NormalizeTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs @@ -197,6 +197,39 @@ private NormalizeTransform(IHost host, ArgumentsBase args, IDataView input, SetMetadata(); } + /// + /// Potentially apply a min-max normalizer to the data's feature column, keeping all existing role + /// mappings except for the feature role mapping. + /// + /// The host environment to use to potentially instantiate the transform + /// The role-mapped data that is potentially going to be modified by this method. + /// The trainer to query with . + /// This method will not modify if the return from that is null or + /// false. + /// True if the normalizer was applied and was modified + public static bool CreateIfNeeded(IHostEnvironment env, ref RoleMappedData data, ITrainer trainer) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(data, nameof(data)); + env.CheckValue(trainer, nameof(trainer)); + + // If this is false or null, we do not want to normalize. + if (trainer.NeedNormalization() != true) + return false; + // If this is true or null, we do not want to normalize. + if (data.Schema.FeaturesAreNormalized() != false) + return false; + var featInfo = data.Schema.Feature; + env.AssertValue(featInfo); // Should be defined, if FEaturesAreNormalized returned a definite value. + + var view = Create(env, new MinMaxArguments() + { + Column = new[] { new AffineColumn() { Name = featInfo.Name, Source = featInfo.Name } } + }, data.Data); + data = RoleMappedData.Create(view, data.Schema.GetColumnRoleNames()); + return true; + } + private NormalizeTransform(IHost host, ModelLoadContext ctx, IDataView input) : base(host, ctx, input, null) { @@ -329,6 +362,43 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou } } + public static class NormalizeUtils + { + /// + /// Tells whether the trainer wants normalization. + /// + /// This method works via testing whether the trainer implements the optional interface + /// , via the Boolean property. + /// If does not implement that interface, then we return null + /// The trainer to query + /// Whether the trainer wants normalization + public static bool? NeedNormalization(this ITrainer trainer) + { + Contracts.CheckValue(trainer, nameof(trainer)); + return (trainer as ITrainerEx)?.NeedNormalization; + } + + /// + /// Returns whether the feature column in the schema is indicated to be normalized. If the features column is not + /// specified on the schema, then this will return null. + /// + /// The role-mapped schema to query + /// Returns null if does not have + /// defined, and otherwise returns a Boolean value as returned from + /// on that feature column + /// + public static bool? FeaturesAreNormalized(this RoleMappedSchema schema) + { + // REVIEW: The role mapped data has the ability to have multiple columns fill the role of features, which is + // useful in some trainers that are nonetheless parameteric and can therefore benefit from normalization. + var featInfo = schema.Feature; + return featInfo == null ? default(bool?) : schema.Schema.IsNormalized(featInfo.Index); + } + } + + /// + /// This contains entry-point definitions related to . + /// public static class Normalize { [TlcModule.EntryPoint(Name = "Transforms.MinMaxNormalizer", Desc = NormalizeTransform.MinMaxNormalizerSummary, UserName = NormalizeTransform.MinMaxNormalizerUserName, ShortName = NormalizeTransform.MinMaxNormalizerShortName)] @@ -402,14 +472,10 @@ public static CommonOutputs.TransformOutput SupervisedBin(IHostEnvironment env, var columnsToNormalize = new List(); foreach (var column in input.Column) { - int col; - if (!schema.TryGetColumnIndex(column.Source, out col)) + if (!schema.TryGetColumnIndex(column.Source, out int col)) throw env.ExceptUserArg(nameof(input.Column), $"Column '{column.Source}' does not exist."); - if (!schema.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, col, ref isNormalized) || - isNormalized.IsFalse) - { + if (!schema.IsNormalized(col)) columnsToNormalize.Add(column); - } } var entryPoints = new List(); diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 654735c4b6..09fa70a8ae 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -1338,14 +1338,13 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB IDataView data = examples.Data; // Convert the label column, if one exists. - var labelInfo = examples.Schema.Label; - if (labelInfo != null) + var labelName = examples.Schema.Label?.Name; + if (labelName != null) { var convArgs = new LabelConvertTransform.Arguments(); - var convCol = new LabelConvertTransform.Column() { Name = labelInfo.Name, Source = labelInfo.Name }; + var convCol = new LabelConvertTransform.Column() { Name = labelName, Source = labelName }; convArgs.Column = new LabelConvertTransform.Column[] { convCol }; data = new LabelConvertTransform(Host, convArgs, data); - labelInfo = ColumnInfo.CreateFromName(data.Schema, convCol.Name, "converted label"); } // Convert the group column, if one exists. var groupInfo = examples.Schema.Group; From f446afa47ef25f7053212923e66a6c841e57f4e8 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Thu, 28 Jun 2018 08:23:29 -0700 Subject: [PATCH 2/5] Minor changes --- src/Microsoft.ML.Core/Data/RoleMappedSchema.cs | 3 ++- src/Microsoft.ML.Data/Commands/TrainCommand.cs | 7 +------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs index 23e5059af2..473ad7584a 100644 --- a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs +++ b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs @@ -11,7 +11,8 @@ namespace Microsoft.ML.Runtime.Data /// /// This contains information about a column in an . It is essentially a convenience cache /// containing the name, column index, and column type for the column. The intended usage is that users of - /// to get the column index and type associated with + /// will have a convenient method of getting the index and type without having to separately query it through the , + /// since practically the first thing a consumer of a will want to do once they get a mappping is /// public sealed class ColumnInfo { diff --git a/src/Microsoft.ML.Data/Commands/TrainCommand.cs b/src/Microsoft.ML.Data/Commands/TrainCommand.cs index c4500a103a..9ac5ce6a5f 100644 --- a/src/Microsoft.ML.Data/Commands/TrainCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainCommand.cs @@ -468,16 +468,11 @@ private static List BacktrackPipe(IDataView dataPipe, out IDataV Contracts.AssertValue(dataPipe); var transforms = new List(); - while (true) + while (dataPipe is IDataTransform xf) { // REVIEW: a malicious user could construct a loop in the Source chain, that would // cause this method to iterate forever (and throw something when the list overflows). There's // no way to insulate from ALL malicious behavior. - - var xf = dataPipe as IDataTransform; - if (xf == null) - break; - transforms.Add(xf); dataPipe = xf.Source; Contracts.AssertValue(dataPipe); From 7fb15aecb48482566a295608cae6c2d1ccf4189c Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Thu, 28 Jun 2018 13:58:19 -0700 Subject: [PATCH 3/5] Ivan comments --- src/Microsoft.ML.Data/Commands/TrainCommand.cs | 1 - src/Microsoft.ML.FastTree/FastTree.cs | 6 ++---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/Microsoft.ML.Data/Commands/TrainCommand.cs b/src/Microsoft.ML.Data/Commands/TrainCommand.cs index 9ac5ce6a5f..72659dfddc 100644 --- a/src/Microsoft.ML.Data/Commands/TrainCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainCommand.cs @@ -522,7 +522,6 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra } } ch.Info("Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off."); - // REVIEW: This verbose constructor should be replaced with zeahmed's enhancements once #405 is committed. IDataView ApplyNormalizer(IHostEnvironment innerEnv, IDataView input) => NormalizeTransform.Create(innerEnv, new NormalizeTransform.MinMaxArguments() { diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 09fa70a8ae..1cb4456b68 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -1347,18 +1347,16 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB data = new LabelConvertTransform(Host, convArgs, data); } // Convert the group column, if one exists. - var groupInfo = examples.Schema.Group; - if (groupInfo != null) + if (examples.Schema.Group != null) { var convArgs = new ConvertTransform.Arguments(); var convCol = new ConvertTransform.Column { ResultType = DataKind.U8 }; - convCol.Name = convCol.Source = groupInfo.Name; + convCol.Name = convCol.Source = examples.Schema.Group.Name; convArgs.Column = new ConvertTransform.Column[] { convCol }; data = new ConvertTransform(Host, convArgs, data); - groupInfo = ColumnInfo.CreateFromName(data.Schema, convCol.Name, "converted group id"); } // Since we've passed it through a few transforms, reconstitute the mapping on the From 3f1e4540b77124a6234b5bbe7da2795fb741e42d Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Mon, 2 Jul 2018 14:22:57 -0700 Subject: [PATCH 4/5] Use normalize convenience constructors. --- src/Microsoft.ML.Data/Commands/TrainCommand.cs | 5 +---- src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/Microsoft.ML.Data/Commands/TrainCommand.cs b/src/Microsoft.ML.Data/Commands/TrainCommand.cs index 72659dfddc..b5b23f9020 100644 --- a/src/Microsoft.ML.Data/Commands/TrainCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainCommand.cs @@ -523,10 +523,7 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra } ch.Info("Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off."); IDataView ApplyNormalizer(IHostEnvironment innerEnv, IDataView input) - => NormalizeTransform.Create(innerEnv, new NormalizeTransform.MinMaxArguments() - { - Column = new[] { new NormalizeTransform.AffineColumn { Source = featureColumn, Name = featureColumn } } - }, input); + => NormalizeTransform.CreateMinMaxNormalizer(innerEnv, input, featureColumn); if (view is IDataLoader loader) view = CompositeDataLoader.ApplyTransform(env, loader, tag: null, creationArgs: null, ApplyNormalizer); diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs index f5faff9d0e..013bc93ecb 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs @@ -222,10 +222,7 @@ public static bool CreateIfNeeded(IHostEnvironment env, ref RoleMappedData data, var featInfo = data.Schema.Feature; env.AssertValue(featInfo); // Should be defined, if FEaturesAreNormalized returned a definite value. - var view = Create(env, new MinMaxArguments() - { - Column = new[] { new AffineColumn() { Name = featInfo.Name, Source = featInfo.Name } } - }, data.Data); + var view = CreateMinMaxNormalizer(env, data.Data, name: featInfo.Name); data = RoleMappedData.Create(view, data.Schema.GetColumnRoleNames()); return true; } From 5153afb6034ad7d2722cb7b98fc6240b9f22b5f1 Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Mon, 2 Jul 2018 14:32:18 -0700 Subject: [PATCH 5/5] PR comments. --- src/Microsoft.ML.Core/Data/MetadataUtils.cs | 2 ++ src/Microsoft.ML.Core/Data/RoleMappedSchema.cs | 3 ++- src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs | 1 + 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs index a231909f66..ca13d03ab3 100644 --- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs +++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs @@ -337,6 +337,8 @@ public static bool HasKeyNames(this ISchema schema, int col, int keyCount) /// /// Returns whether a column has the metadata set to true. + /// That metadata should be set when the data has undergone transforms that would render it + /// "normalized." /// /// The schema to query /// Which column in the schema to query diff --git a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs index 473ad7584a..302d369489 100644 --- a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs +++ b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs @@ -12,7 +12,8 @@ namespace Microsoft.ML.Runtime.Data /// This contains information about a column in an . It is essentially a convenience cache /// containing the name, column index, and column type for the column. The intended usage is that users of /// will have a convenient method of getting the index and type without having to separately query it through the , - /// since practically the first thing a consumer of a will want to do once they get a mappping is + /// since practically the first thing a consumer of a will want to do once they get a mappping is get + /// the type and index of the corresponding column. /// public sealed class ColumnInfo { diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs index 013bc93ecb..318bbb44f1 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeTransform.cs @@ -388,6 +388,7 @@ public static class NormalizeUtils { // REVIEW: The role mapped data has the ability to have multiple columns fill the role of features, which is // useful in some trainers that are nonetheless parameteric and can therefore benefit from normalization. + Contracts.CheckValue(schema, nameof(schema)); var featInfo = schema.Feature; return featInfo == null ? default(bool?) : schema.Schema.IsNormalized(featInfo.Index); }