From e0c85cde4f52df86a45fb9395a4d97648ee65a74 Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Fri, 5 Oct 2018 10:32:34 -0700 Subject: [PATCH 01/13] started conversion work --- .../NAIndicatorTransform.cs | 159 +++++++++++++----- 1 file changed, 120 insertions(+), 39 deletions(-) diff --git a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs index 044751f459..a8b3686ecd 100644 --- a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs +++ b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs @@ -4,25 +4,34 @@ using System; using System.Collections.Generic; +using System.IO; +using System.Linq; using System.Text; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.Conversion; +using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; -[assembly: LoadableClass(typeof(NAIndicatorTransform), typeof(NAIndicatorTransform.Arguments), typeof(SignatureDataTransform), - NAIndicatorTransform.FriendlyName, "NAIndicatorTransform", "NAIndicator", NAIndicatorTransform.ShortName, DocName = "transform/NAHandle.md")] +[assembly: LoadableClass(NAIndicatorTransform.Summary, typeof(IDataTransform), typeof(NAIndicatorTransform), typeof(NAIndicatorTransform.Arguments), typeof(SignatureDataTransform), + NAIndicatorTransform.FriendlyName, NAIndicatorTransform.LoadName, "NAIndicator", NAIndicatorTransform.ShortName, DocName = "transform/NAHandle.md")] -[assembly: LoadableClass(typeof(NAIndicatorTransform), null, typeof(SignatureLoadDataTransform), - NAIndicatorTransform.FriendlyName, NAIndicatorTransform.LoaderSignature)] +[assembly: LoadableClass(NAIndicatorTransform.Summary, typeof(IDataTransform), typeof(NAIndicatorTransform), null, typeof(SignatureLoadDataTransform), + NAIndicatorTransform.FriendlyName, NAIndicatorTransform.LoadName)] + +[assembly: LoadableClass(NAIndicatorTransform.Summary, typeof(NAIndicatorTransform), null, typeof(SignatureLoadModel), + NAIndicatorTransform.FriendlyName, NAIndicatorTransform.LoadName)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(NAIndicatorTransform), null, typeof(SignatureLoadRowMapper), + NAIndicatorTransform.FriendlyName, NAIndicatorTransform.LoadName)] namespace Microsoft.ML.Runtime.Data { /// - public sealed class NAIndicatorTransform : OneToOneTransformBase + public sealed class NAIndicatorTransform : OneToOneTransformerBase { public sealed class Column : OneToOneColumn { @@ -49,7 +58,7 @@ public sealed class Arguments : TransformInputBase public Column[] Column; } - public const string LoaderSignature = "NaIndicatorTransform"; + public const string LoadName = "NaIndicatorTransform"; private static VersionInfo GetVersionInfo() { @@ -59,7 +68,7 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature, + loaderSignature: LoadName, loaderAssemblyName: typeof(NAIndicatorTransform).Assembly.FullName); } @@ -68,77 +77,145 @@ private static VersionInfo GetVersionInfo() internal const string FriendlyName = "NA Indicator Transform"; internal const string ShortName = "NAInd"; - private static string TestType(ColumnType type) + internal static string TestType(ColumnType type) { // Item type must have an NA value. We'll get the predicate again later when we're ready to use it. Delegate del; if (Conversions.Instance.TryGetIsNAPredicate(type.ItemType, out del)) return null; return string.Format("Type '{0}' is not supported by {1} since it doesn't have an NA value", - type, LoaderSignature); + type, LoadName); + } + + // TODO: Why do we even need an object for this? maybe because we want to deal with array of these things + public class ColumnInfo + { + public readonly string Input; + public readonly string Output; + + /// + /// Describes how the transformer handles one column pair. + /// + /// Name of input column. + /// Name of output column. + public ColumnInfo(string input, string output) + { + Input = input; + Output = output; + } + } + + private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + { + Contracts.CheckValue(columns, nameof(columns)); + return columns.Select(x => (x.Input, x.Output)).ToArray(); } private const string RegistrationName = "NaIndicator"; // The output column types, parallel to Infos. - private readonly ColumnType[] _types; + private readonly ColumnType[] _replaceTypes; + + protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) + { + var type = inputSchema.GetColumnType(srcCol); + string reason = TestType(type); + if (reason != null) + throw Host.ExceptParam(nameof(inputSchema), reason); + } /// /// Convenience constructor for public facing API. /// /// Host Environment. /// Input . This is the output from previous transform or loader. - /// Name of the output column. - /// Name of the column to be transformed. If this is null '' will be used. - public NAIndicatorTransform(IHostEnvironment env, IDataView input, string name, string source = null) - : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } } }, input) + /// + public NAIndicatorTransform(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NAReplaceTransform)), GetColumnPairs(columns)) { + // Check that all the input columns are present and correct. + for (int i = 0; i < ColumnPairs.Length; i++) + { + if (!input.Schema.TryGetColumnIndex(ColumnPairs[i].input, out int srcCol)) + throw Host.ExceptSchemaMismatch(nameof(input), "input", ColumnPairs[i].input); + CheckInputColumn(input.Schema, i, srcCol); + } } - /// - /// Public constructor corresponding to SignatureDataTransform. - /// - public NAIndicatorTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, - input, TestType) + private NAIndicatorTransform(IHost host, ModelLoadContext ctx) + : base(host, ctx) { - Host.AssertNonEmpty(Infos); - Host.Assert(Infos.Length == Utils.Size(args.Column)); - - _types = GetTypesAndMetadata(); + Host.AssertValue(ctx); } - private NAIndicatorTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, TestType) + // Factory method for SignatureDataTransform. + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { - Host.AssertValue(ctx); - Host.AssertNonEmpty(Infos); + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); - // *** Binary format *** - // + env.CheckValue(args.Column, nameof(args.Column)); + var cols = new ColumnInfo[args.Column.Length]; + for (int i = 0; i < cols.Length; i++) + { + var item = args.Column[i]; - _types = GetTypesAndMetadata(); + cols[i] = new ColumnInfo(item.Source, item.Name); + }; + return new NAIndicatorTransform(env, input, cols).MakeDataTransform(input); } - public static NAIndicatorTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + public static NAIndicatorTransform Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); + var host = env.Register(RegistrationName); + + host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return h.Apply("Loading Model", ch => new NAIndicatorTransform(h, ctx, input)); + + return new NAIndicatorTransform(host, ctx); + } + + public static IDataTransform Create(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) + => new NAIndicatorTransform(env, input, columns).MakeDataTransform(input); + + // Factory method for SignatureLoadDataTransform. + public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); + + // Factory method for SignatureLoadRowMapper. + public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + + private void WriteTypeAndValue(Stream stream, BinarySaver saver, ColumnType type, T rep) + { + Host.AssertValue(stream); + Host.AssertValue(saver); + Host.Assert(type.RawType == typeof(T) || type.ItemType.RawType == typeof(T)); + + if (!saver.TryWriteTypeAndValue(stream, type, ref rep, out int bytesWritten)) + throw Host.Except("We do not know how to serialize terms of type '{0}'", type); } public override void Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); - // *** Binary format *** - // - SaveBase(ctx); + SaveColumns(ctx); + var saver = new BinarySaver(Host, new BinarySaver.Arguments()); + for (int iinfo = 0; iinfo < _replaceTypes.Length; iinfo++) + { + var repType = _replaceTypes[iinfo].ItemType; + object[] args = new object[] { ctx.Writer.BaseStream, saver, repType, ???? }; + Action func = WriteTypeAndValue; + Host.Assert(repValue.GetType() == _replaceTypes[iinfo].RawType || repValue.GetType() == _replaceTypes[iinfo].ItemType.RawType); + var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(repValue.GetType()); + meth.Invoke(this, args); + } } private ColumnType[] GetTypesAndMetadata() @@ -385,5 +462,9 @@ private void FillValues(int srcLength, ref VBuffer dst, List indices, dst = new VBuffer(srcLength, dstValues, dstIndices); } } + + protected override IRowMapper MakeRowMapper(ISchema schema) + { + } } } From 9e23955303a106278c9e1a468f09f9a11a05ec02 Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Fri, 5 Oct 2018 10:33:48 -0700 Subject: [PATCH 02/13] started conversion work --- .../NAIndicatorTransform.cs | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs index a8b3686ecd..e23757312e 100644 --- a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs +++ b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs @@ -206,15 +206,18 @@ public override void Save(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); SaveColumns(ctx); - var saver = new BinarySaver(Host, new BinarySaver.Arguments()); - for (int iinfo = 0; iinfo < _replaceTypes.Length; iinfo++) + Contracts.AssertValue(ctx); + + // *** Binary format *** + // int: number of added columns + // for each added column + // int: id of output column name + // int: id of input column name + ctx.Writer.Write(.Length); + foreach (var info in Infos) { - var repType = _replaceTypes[iinfo].ItemType; - object[] args = new object[] { ctx.Writer.BaseStream, saver, repType, ???? }; - Action func = WriteTypeAndValue; - Host.Assert(repValue.GetType() == _replaceTypes[iinfo].RawType || repValue.GetType() == _replaceTypes[iinfo].ItemType.RawType); - var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(repValue.GetType()); - meth.Invoke(this, args); + ctx.SaveNonEmptyString(info.Name); + ctx.SaveNonEmptyString(Input.GetColumnName(info.Source)); } } From e0e7d3baf93b6ca06c79026a93346f25fc4ae6bf Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Wed, 10 Oct 2018 10:47:03 -0700 Subject: [PATCH 03/13] wrote a first version of the conversion, need to debug it and make sure tests pass --- .../NAHandleTransform.cs | 2 +- src/Microsoft.ML.Transforms/NAHandling.cs | 2 +- .../NAIndicatorTransform.cs | 424 +++++++++++++++--- .../Transformers/NAIndicatorTests.cs | 130 ++++++ 4 files changed, 486 insertions(+), 72 deletions(-) create mode 100644 test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs diff --git a/src/Microsoft.ML.Transforms/NAHandleTransform.cs b/src/Microsoft.ML.Transforms/NAHandleTransform.cs index bf3ab5f41a..8645d72647 100644 --- a/src/Microsoft.ML.Transforms/NAHandleTransform.cs +++ b/src/Microsoft.ML.Transforms/NAHandleTransform.cs @@ -212,7 +212,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV // Create the indicator columns. if (naIndicatorCols.Count > 0) - output = new NAIndicatorTransform(h, new NAIndicatorTransform.Arguments() { Column = naIndicatorCols.ToArray() }, input); + output = NAIndicatorTransform.Create(h, new NAIndicatorTransform.Arguments() { Column = naIndicatorCols.ToArray() }, input); // Convert the indicator columns to the correct type so that they can be concatenated to the NAReplace outputs. if (naConvCols.Count > 0) diff --git a/src/Microsoft.ML.Transforms/NAHandling.cs b/src/Microsoft.ML.Transforms/NAHandling.cs index 8f8fdaabb9..8a6635ef77 100644 --- a/src/Microsoft.ML.Transforms/NAHandling.cs +++ b/src/Microsoft.ML.Transforms/NAHandling.cs @@ -71,7 +71,7 @@ public static CommonOutputs.TransformOutput Handle(IHostEnvironment env, NAHandl public static CommonOutputs.TransformOutput Indicator(IHostEnvironment env, NAIndicatorTransform.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "NAIndicator", input); - var xf = new NAIndicatorTransform(h, input, input.Data); + var xf = NAIndicatorTransform.Create(h, input, input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, xf, input.Data), diff --git a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs index e23757312e..1b101f70ad 100644 --- a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs +++ b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs @@ -4,17 +4,19 @@ using System; using System.Collections.Generic; -using System.IO; using System.Linq; using System.Text; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.Conversion; -using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.Model.Onnx; +using Microsoft.ML.StaticPipe; +using Microsoft.ML.StaticPipe.Runtime; [assembly: LoadableClass(NAIndicatorTransform.Summary, typeof(IDataTransform), typeof(NAIndicatorTransform), typeof(NAIndicatorTransform.Arguments), typeof(SignatureDataTransform), NAIndicatorTransform.FriendlyName, NAIndicatorTransform.LoadName, "NAIndicator", NAIndicatorTransform.ShortName, DocName = "transform/NAHandle.md")] @@ -105,17 +107,19 @@ public ColumnInfo(string input, string output) } } + private const string RegistrationName = nameof(NAIndicatorTransform); + + // The input column types + private ColumnType[] _inputTypes; + // The output column types + private ColumnType[] _outputTypes; + private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) { Contracts.CheckValue(columns, nameof(columns)); return columns.Select(x => (x.Input, x.Output)).ToArray(); } - private const string RegistrationName = "NaIndicator"; - - // The output column types, parallel to Infos. - private readonly ColumnType[] _replaceTypes; - protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) { var type = inputSchema.GetColumnType(srcCol); @@ -124,14 +128,37 @@ protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCo throw Host.ExceptParam(nameof(inputSchema), reason); } + private (ColumnType[], ColumnType[]) GetTypes(ISchema schema) + { + var inputTypes = new ColumnType[ColumnPairs.Length]; + var outputTypes = new ColumnType[ColumnPairs.Length]; + for (int i = 0; i < ColumnPairs.Length; i++) + { + schema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc); + var type = schema.GetColumnType(colSrc); + + if (!type.IsVector) + { + inputTypes[i] = type.AsPrimitive; + outputTypes[i] = BoolType.Instance; + } + else + { + inputTypes[i] = new VectorType(type.ItemType.AsPrimitive, type.AsVector); + outputTypes[i] = new VectorType(BoolType.Instance, type.AsVector); + } + } + return (inputTypes, outputTypes); + } + /// /// Convenience constructor for public facing API. /// /// Host Environment. /// Input . This is the output from previous transform or loader. - /// + /// TODO public NAIndicatorTransform(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NAReplaceTransform)), GetColumnPairs(columns)) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NAIndicatorTransform)), GetColumnPairs(columns)) { // Check that all the input columns are present and correct. for (int i = 0; i < ColumnPairs.Length; i++) @@ -140,12 +167,15 @@ public NAIndicatorTransform(IHostEnvironment env, IDataView input, params Column throw Host.ExceptSchemaMismatch(nameof(input), "input", ColumnPairs[i].input); CheckInputColumn(input.Schema, i, srcCol); } + (_inputTypes, _outputTypes) = GetTypes(input.Schema); } private NAIndicatorTransform(IHost host, ModelLoadContext ctx) : base(host, ctx) { Host.AssertValue(ctx); + _outputTypes = null; + _inputTypes = null; } // Factory method for SignatureDataTransform. @@ -188,14 +218,18 @@ public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - private void WriteTypeAndValue(Stream stream, BinarySaver saver, ColumnType type, T rep) + /// + /// Returns the isNA predicate for the respective type. + /// + private Delegate GetIsNADelegate(ColumnType type) { - Host.AssertValue(stream); - Host.AssertValue(saver); - Host.Assert(type.RawType == typeof(T) || type.ItemType.RawType == typeof(T)); + Func func = GetIsNADelegate; + return Utils.MarshalInvoke(func, type.ItemType.RawType, type); + } - if (!saver.TryWriteTypeAndValue(stream, type, ref rep, out int bytesWritten)) - throw Host.Except("We do not know how to serialize terms of type '{0}'", type); + private Delegate GetIsNADelegate(ColumnType type) + { + return Conversions.Instance.GetIsNAPredicate(type.ItemType); } public override void Save(ModelSaveContext ctx) @@ -206,58 +240,25 @@ public override void Save(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); SaveColumns(ctx); - Contracts.AssertValue(ctx); - - // *** Binary format *** - // int: number of added columns - // for each added column - // int: id of output column name - // int: id of input column name - ctx.Writer.Write(.Length); - foreach (var info in Infos) - { - ctx.SaveNonEmptyString(info.Name); - ctx.SaveNonEmptyString(Input.GetColumnName(info.Source)); - } - } - - private ColumnType[] GetTypesAndMetadata() - { - var md = Metadata; - var types = new ColumnType[Infos.Length]; - for (int iinfo = 0; iinfo < Infos.Length; iinfo++) - { - var type = Infos[iinfo].TypeSrc; - - if (!type.IsVector) - types[iinfo] = BoolType.Instance; - else - types[iinfo] = new VectorType(BoolType.Instance, type.AsVector); - // Pass through slot name metadata. - using (var bldr = md.BuildMetadata(iinfo, Source.Schema, Infos[iinfo].Source, MetadataUtils.Kinds.SlotNames)) - { - // Output is normalized. - bldr.AddPrimitive(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, true); - } - } - md.Seal(); - return types; } - protected override ColumnType GetColumnTypeCore(int iinfo) + private ColumnType GetColumnTypeCore(int iinfo) { - Host.Assert(0 <= iinfo & iinfo < Infos.Length); - return _types[iinfo]; + Host.Assert(0 <= iinfo & iinfo < ColumnPairs.Length); + return _outputTypes[iinfo]; } - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) + private Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) { Host.AssertValueOrNull(ch); Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); + Host.Assert(0 <= iinfo && iinfo < ColumnPairs.Length); disposer = null; + // TODO: is there a better way then checking every time in the getgetter? + if (_outputTypes == null || _inputTypes == null) + (_inputTypes, _outputTypes) = GetTypes(input.Schema); - if (!Infos[iinfo].TypeSrc.IsVector) + if (!_outputTypes[iinfo].IsVector) return ComposeGetterOne(input, iinfo); return ComposeGetterVec(input, iinfo); } @@ -268,16 +269,16 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou private ValueGetter ComposeGetterOne(IRow input, int iinfo) { Func> func = ComposeGetterOne; - return Utils.MarshalInvoke(func, Infos[iinfo].TypeSrc.RawType, input, iinfo); + return Utils.MarshalInvoke(func, _inputTypes[iinfo].RawType, input, iinfo); } - /// - /// Tests if a value is NA for scalars. - /// private ValueGetter ComposeGetterOne(IRow input, int iinfo) { - var getSrc = GetSrcGetter(input, iinfo); - var isNA = Conversions.Instance.GetIsNAPredicate(input.Schema.GetColumnType(Infos[iinfo].Source)); + Host.Assert(0 <= iinfo & iinfo < ColumnPairs.Length); + Host.Assert(input.IsColumnActive(iinfo)); + + var getSrc = input.GetGetter(iinfo); + var isNA = Conversions.Instance.GetIsNAPredicate(_inputTypes[iinfo]); T src = default(T); return (ref bool dst) => @@ -293,16 +294,16 @@ private ValueGetter ComposeGetterOne(IRow input, int iinfo) private ValueGetter> ComposeGetterVec(IRow input, int iinfo) { Func>> func = ComposeGetterVec; - return Utils.MarshalInvoke(func, Infos[iinfo].TypeSrc.ItemType.RawType, input, iinfo); + return Utils.MarshalInvoke(func, _inputTypes[iinfo].ItemType.RawType, input, iinfo); } - /// - /// Tests if a value is NA for vectors. - /// private ValueGetter> ComposeGetterVec(IRow input, int iinfo) { - var getSrc = GetSrcGetter>(input, iinfo); - var isNA = Conversions.Instance.GetIsNAPredicate(input.Schema.GetColumnType(Infos[iinfo].Source).ItemType); + Host.Assert(0 <= iinfo & iinfo < ColumnPairs.Length); + Host.Assert(input.IsColumnActive(iinfo)); + + var getSrc = input.GetGetter>(iinfo); + var isNA = Conversions.Instance.GetIsNAPredicate(_inputTypes[iinfo]); var val = default(T); bool defaultIsNA = isNA(ref val); var src = default(VBuffer); @@ -467,7 +468,290 @@ private void FillValues(int srcLength, ref VBuffer dst, List indices, } protected override IRowMapper MakeRowMapper(ISchema schema) + => new Mapper(this, schema); + + private sealed class Mapper : MapperBase + { + private sealed class ColInfo + { + public readonly string Name; + public readonly string Source; + public readonly ColumnType TypeSrc; + + public ColInfo(string name, string source, ColumnType type) + { + Name = name; + Source = source; + TypeSrc = type; + } + } + + private readonly NAIndicatorTransform _parent; + private readonly ColInfo[] _infos; + private readonly ColumnType[] _types; + // The isNA delegates, parallel to Infos. + private readonly Delegate[] _isNAs; + + public Mapper(NAIndicatorTransform parent, ISchema inputSchema) + : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) + { + _parent = parent; + _infos = CreateInfos(inputSchema); + _types = new ColumnType[_parent.ColumnPairs.Length]; + _isNAs = new Delegate[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + var type = _infos[i].TypeSrc; + if (!type.IsVector) + _types[i] = BoolType.Instance; + else + _types[i] = new VectorType(BoolType.Instance, type.AsVector); + + _types[i] = type; + _isNAs[i] = _parent.GetIsNADelegate(type); + } + } + + private ColInfo[] CreateInfos(ISchema inputSchema) + { + Host.AssertValue(inputSchema); + var infos = new ColInfo[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colSrc)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); + _parent.CheckInputColumn(inputSchema, i, colSrc); + var type = inputSchema.GetColumnType(colSrc); + infos[i] = new ColInfo(_parent.ColumnPairs[i].output, _parent.ColumnPairs[i].input, type); + } + return infos; + } + + public override RowMapperColumnInfo[] GetOutputColumns() + { + var result = new RowMapperColumnInfo[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + result[i] = new RowMapperColumnInfo(_parent.ColumnPairs[i].output, _types[i], default); + return result; + } + + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) + { + Host.AssertValue(input); + Host.Assert(0 <= iinfo && iinfo < _infos.Length); + disposer = null; + + if (!_infos[iinfo].TypeSrc.IsVector) + return ComposeGetterOne(input, iinfo); + return ComposeGetterVec(input, iinfo); + } + + /// + /// Getter generator for single valued inputs. + /// + private ValueGetter ComposeGetterOne(IRow input, int iinfo) + => Utils.MarshalInvoke(ComposeGetterOne, _infos[iinfo].TypeSrc.RawType, input, iinfo); + + private ValueGetter ComposeGetterOne(IRow input, int iinfo) + { + var getSrc = input.GetGetter(ColMapNewToOld[iinfo]); + var src = default(T); + var isNA = (RefPredicate)_isNAs[iinfo]; + + ValueGetter getter; + + return getter = + (ref bool dst) => + { + getSrc(ref src); + dst = isNA(ref src); + }; + } + + /// + /// Getter generator for vector valued inputs. + /// + private ValueGetter> ComposeGetterVec(IRow input, int iinfo) + => Utils.MarshalInvoke(ComposeGetterVec, _infos[iinfo].TypeSrc.ItemType.RawType, input, iinfo); + + private ValueGetter> ComposeGetterVec(IRow input, int iinfo) + { + var getSrc = input.GetGetter>(ColMapNewToOld[iinfo]); + var isNA = (RefPredicate)_isNAs[iinfo]; + var val = default(T); + var defaultIsNA = isNA(ref val); + var src = default(VBuffer); + var indices = new List(); + + ValueGetter> getter; + + return getter = + (ref VBuffer dst) => + { + // Sense indicates if the values added to the indices list represent NAs or non-NAs. + bool sense; + getSrc(ref src); + _parent.FindNAs(ref src, isNA, defaultIsNA, indices, out sense); + _parent.FillValues(src.Length, ref dst, indices, sense); + }; + } + } + } + + public sealed class NAIndicatorEstimator : IEstimator + { + private readonly IHost _host; + private readonly NAIndicatorTransform.ColumnInfo[] _columns; + + public NAIndicatorEstimator(IHostEnvironment env, string name, string source = null) + : this(env, new NAIndicatorTransform.ColumnInfo(source ?? name, name)) + { + } + + public NAIndicatorEstimator(IHostEnvironment env, params NAIndicatorTransform.ColumnInfo[] columns) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(NAIndicatorEstimator)); + _columns = columns; + } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var colInfo in _columns) + { + if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + string reason = NAIndicatorTransform.TestType(col.ItemType); + if (reason != null) + throw _host.ExceptParam(nameof(inputSchema), reason); + var type = !col.ItemType.IsVector ? col.ItemType : new VectorType(col.ItemType.ItemType.AsPrimitive, col.ItemType.AsVector); + result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, type, false, default); + } + return new SchemaShape(result.Values); + } + + public NAIndicatorTransform Fit(IDataView input) => new NAIndicatorTransform(_host, input, _columns); + } + + /// + /// Extension methods for the static-pipeline over objects. + /// + public static class NAIndicatorExtensions + { + private interface IColInput + { + PipelineColumn Input { get; } + } + + private sealed class OutScalar : Scalar, IColInput + { + public PipelineColumn Input { get; } + + public OutScalar(Scalar input) + : base(Reconciler.Inst, input) + { + Input = input; + } + } + + private sealed class OutVectorColumn : Vector, IColInput + { + public PipelineColumn Input { get; } + + public OutVectorColumn(Vector input) + : base(Reconciler.Inst, input) + { + Input = input; + } + } + + private sealed class OutVarVectorColumn : VarVector, IColInput + { + public PipelineColumn Input { get; } + + public OutVarVectorColumn(VarVector input) + : base(Reconciler.Inst, input) + { + Input = input; + } + } + + private sealed class Reconciler : EstimatorReconciler + { + public static Reconciler Inst = new Reconciler(); + + private Reconciler() { } + + public override IEstimator Reconcile(IHostEnvironment env, + PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, + IReadOnlyDictionary outputNames, + IReadOnlyCollection usedNames) + { + var infos = new NAIndicatorTransform.ColumnInfo[toOutput.Length]; + for (int i = 0; i < toOutput.Length; ++i) + { + var col = (IColInput)toOutput[i]; + infos[i] = new NAIndicatorTransform.ColumnInfo(inputNames[col.Input], outputNames[toOutput[i]]); + } + return new NAIndicatorEstimator(env, infos); + } + } + + public static Scalar IsMissingValue(this Scalar input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutScalar(input); + } + + public static Scalar IsMissingValue(this Scalar input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutScalar(input); + } + + public static Scalar IsMissingValue(this Scalar input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutScalar(input); + } + + public static Vector IsMissingValue(this Vector input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input); + } + + public static Vector IsMissingValue(this Vector input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input); + } + + public static Vector IsMissingValue(this Vector input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVectorColumn(input); + } + + public static VarVector IsMissingValue(this VarVector input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVarVectorColumn(input); + } + + public static VarVector IsMissingValue(this VarVector input) + { + Contracts.CheckValue(input, nameof(input)); + return new OutVarVectorColumn(input); + } + + public static VarVector IsMissingValue(this VarVector input) { + Contracts.CheckValue(input, nameof(input)); + return new OutVarVectorColumn(input); } } -} +} \ No newline at end of file diff --git a/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs b/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs new file mode 100644 index 0000000000..3fbebdae88 --- /dev/null +++ b/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs @@ -0,0 +1,130 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Data.IO; +using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.Runtime.Tools; +using System.IO; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests.Transformers +{ + public class NAIndicatorTests : TestDataPipeBase + { + private class TestClass + { + public float A; + public double B; + [VectorType(2)] + public float[] C; + [VectorType(2)] + public double[] D; + } + + public NAIndicatorTests(ITestOutputHelper output) : base(output) + { + } + + [Fact] + public void NAIndicatorWorkout() + { + var data = new[] { + new TestClass() { A = 1, B = 3, C = new float[2]{ 1, 2 } , D = new double[2]{ 3,4} }, + new TestClass() { A = float.NaN, B = double.NaN, C = new float[2]{ float.NaN, float.NaN } , D = new double[2]{ double.NaN,double.NaN}}, + new TestClass() { A = float.NegativeInfinity, B = double.NegativeInfinity, C = new float[2]{ float.NegativeInfinity, float.NegativeInfinity } , D = new double[2]{ double.NegativeInfinity, double.NegativeInfinity}}, + new TestClass() { A = float.PositiveInfinity, B = double.PositiveInfinity, C = new float[2]{ float.PositiveInfinity, float.PositiveInfinity, } , D = new double[2]{ double.PositiveInfinity, double.PositiveInfinity}}, + new TestClass() { A = 2, B = 1, C = new float[2]{ 3, 4 } , D = new double[2]{ 5,6}}, + }; + + var dataView = ComponentCreation.CreateDataView(Env, data); + var pipe = new NAIndicatorEstimator(Env, + new NAIndicatorTransform.ColumnInfo("A", "NAA"), + new NAIndicatorTransform.ColumnInfo("B", "NAC"), + new NAIndicatorTransform.ColumnInfo("C", "NAD"), + new NAIndicatorTransform.ColumnInfo("D", "NAE")); + // write a simple test with one NAindicatortransform and try to inspect the columns using pete's code. might be easier! + //TestEstimatorCore(pipe, dataView); + Done(); + } + + [Fact] + public void NAIndicatorStatic() + { + string dataPath = GetDataPath("breast-cancer.txt"); + var reader = TextLoader.CreateReader(Env, ctx => ( + ScalarString: ctx.LoadText(1), + ScalarFloat: ctx.LoadFloat(1), + ScalarDouble: ctx.LoadDouble(1), + VectorString: ctx.LoadText(1, 4), + VectorFloat: ctx.LoadFloat(1, 4), + VectorDoulbe: ctx.LoadDouble(1, 4) + )); + + var data = reader.Read(new MultiFileSource(dataPath)); + var wrongCollection = new[] { new TestClass() { A = 1, B = 3, C = new float[2] { 1, 2 }, D = new double[2] { 3, 4 } } }; + var invalidData = ComponentCreation.CreateDataView(Env, wrongCollection); + + var est = data.MakeNewEstimator(). + Append(row => ( + A: row.ScalarString.IsMissingValue(), + B: row.ScalarDouble.IsMissingValue(), + C: row.VectorString.IsMissingValue(), + D: row.VectorFloat.IsMissingValue(), + F: row.VectorDoulbe.IsMissingValue() + )); + + TestEstimatorCore(est.AsDynamic, data.AsDynamic, invalidInput: invalidData); + var outputPath = GetOutputPath("NAIndicator", "featurized.tsv"); + using (var ch = Env.Start("save")) + { + var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); + IDataView savedData = TakeFilter.Create(Env, est.Fit(data).Transform(data).AsDynamic, 4); + savedData = new ChooseColumnsTransform(Env, savedData, "A", "B", "C", "D"); + using (var fs = File.Create(outputPath)) + DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); + } + + CheckEquality("NAIndicator", "featurized.tsv"); + Done(); + } + + [Fact] + public void TestCommandLine() + { + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=NAIndicator{col=B:A} in=f:\2.txt" }), (int)0); + } + + [Fact] + public void TestOldSavingAndLoading() + { + var data = new[] { + new TestClass() { A = 1, B = 3, C = new float[2]{ 1, 2 } , D = new double[2]{ 3,4} }, + new TestClass() { A = float.NaN, B = double.NaN, C = new float[2]{ float.NaN, float.NaN } , D = new double[2]{ double.NaN,double.NaN}}, + new TestClass() { A = float.NegativeInfinity, B = double.NegativeInfinity, C = new float[2]{ float.NegativeInfinity, float.NegativeInfinity } , D = new double[2]{ double.NegativeInfinity, double.NegativeInfinity}}, + new TestClass() { A = float.PositiveInfinity, B = double.PositiveInfinity, C = new float[2]{ float.PositiveInfinity, float.PositiveInfinity, } , D = new double[2]{ double.PositiveInfinity, double.PositiveInfinity}}, + new TestClass() { A = 2, B = 1 , C = new float[2]{ 3, 4 } , D = new double[2]{ 5,6}}, + }; + + var dataView = ComponentCreation.CreateDataView(Env, data); + var pipe = new NAIndicatorEstimator(Env, + new NAIndicatorTransform.ColumnInfo("A", "NAA"), + new NAIndicatorTransform.ColumnInfo("B", "NAC"), + new NAIndicatorTransform.ColumnInfo("C", "NAD"), + new NAIndicatorTransform.ColumnInfo("D", "NAE")); + + var result = pipe.Fit(dataView).Transform(dataView); + var resultRoles = new RoleMappedData(result); + using (var ms = new MemoryStream()) + { + TrainUtils.SaveModel(Env, Env.Start("saving"), ms, null, resultRoles); + ms.Position = 0; + var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms); + } + } + } +} From 9c1e7584c22f8086e33feb1c15bd42430064c493 Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Wed, 10 Oct 2018 14:37:14 -0700 Subject: [PATCH 04/13] finished debugging tests --- .../NAIndicatorTransform.cs | 54 +++++++++++-------- .../SingleDebug/NAIndicator/featurized.tsv | 17 ++++++ .../SingleRelease/NAIndicator/featurized.tsv | 17 ++++++ .../Transformers/NAIndicatorTests.cs | 42 +++++++++++---- 4 files changed, 98 insertions(+), 32 deletions(-) create mode 100644 test/BaselineOutput/SingleDebug/NAIndicator/featurized.tsv create mode 100644 test/BaselineOutput/SingleRelease/NAIndicator/featurized.tsv diff --git a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs index 1b101f70ad..eb88b497f4 100644 --- a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs +++ b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs @@ -89,7 +89,6 @@ internal static string TestType(ColumnType type) type, LoadName); } - // TODO: Why do we even need an object for this? maybe because we want to deal with array of these things public class ColumnInfo { public readonly string Input; @@ -156,7 +155,7 @@ protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCo /// /// Host Environment. /// Input . This is the output from previous transform or loader. - /// TODO + /// Specifies the names of the input columns for the transform and the resulting output column names. public NAIndicatorTransform(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NAIndicatorTransform)), GetColumnPairs(columns)) { @@ -178,7 +177,9 @@ private NAIndicatorTransform(IHost host, ModelLoadContext ctx) _inputTypes = null; } - // Factory method for SignatureDataTransform. + /// + /// Factory method for SignatureDataTransform. + /// public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { Contracts.CheckValue(env, nameof(env)); @@ -196,6 +197,9 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV return new NAIndicatorTransform(env, input, cols).MakeDataTransform(input); } + /// + /// Factory method for SignatureDataTransform. + /// public static NAIndicatorTransform Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); @@ -207,14 +211,21 @@ public static NAIndicatorTransform Create(IHostEnvironment env, ModelLoadContext return new NAIndicatorTransform(host, ctx); } + /// + /// Factory method for SignatureDataTransform. + /// public static IDataTransform Create(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) => new NAIndicatorTransform(env, input, columns).MakeDataTransform(input); - // Factory method for SignatureLoadDataTransform. + /// + /// Factory method for SignatureDataTransform. + /// public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) => Create(env, ctx).MakeDataTransform(input); - // Factory method for SignatureLoadRowMapper. + /// + /// Factory method for SignatureDataTransform. + /// public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); @@ -254,7 +265,6 @@ private Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action di Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < ColumnPairs.Length); disposer = null; - // TODO: is there a better way then checking every time in the getgetter? if (_outputTypes == null || _inputTypes == null) (_inputTypes, _outputTypes) = GetTypes(input.Schema); @@ -269,7 +279,7 @@ private Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action di private ValueGetter ComposeGetterOne(IRow input, int iinfo) { Func> func = ComposeGetterOne; - return Utils.MarshalInvoke(func, _inputTypes[iinfo].RawType, input, iinfo); + return Utils.MarshalInvoke(func, _outputTypes[iinfo].RawType, input, iinfo); } private ValueGetter ComposeGetterOne(IRow input, int iinfo) @@ -294,7 +304,7 @@ private ValueGetter ComposeGetterOne(IRow input, int iinfo) private ValueGetter> ComposeGetterVec(IRow input, int iinfo) { Func>> func = ComposeGetterVec; - return Utils.MarshalInvoke(func, _inputTypes[iinfo].ItemType.RawType, input, iinfo); + return Utils.MarshalInvoke(func, _outputTypes[iinfo].ItemType.RawType, input, iinfo); } private ValueGetter> ComposeGetterVec(IRow input, int iinfo) @@ -506,8 +516,6 @@ public Mapper(NAIndicatorTransform parent, ISchema inputSchema) _types[i] = BoolType.Instance; else _types[i] = new VectorType(BoolType.Instance, type.AsVector); - - _types[i] = type; _isNAs[i] = _parent.GetIsNADelegate(type); } } @@ -626,7 +634,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) string reason = NAIndicatorTransform.TestType(col.ItemType); if (reason != null) throw _host.ExceptParam(nameof(inputSchema), reason); - var type = !col.ItemType.IsVector ? col.ItemType : new VectorType(col.ItemType.ItemType.AsPrimitive, col.ItemType.AsVector); + ColumnType type = !col.ItemType.IsVector ? (ColumnType) BoolType.Instance : new VectorType(BoolType.Instance, col.ItemType.AsVector); result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, type, false, default); } return new SchemaShape(result.Values); @@ -645,7 +653,7 @@ private interface IColInput PipelineColumn Input { get; } } - private sealed class OutScalar : Scalar, IColInput + private sealed class OutScalar : Scalar, IColInput { public PipelineColumn Input { get; } @@ -656,7 +664,7 @@ public OutScalar(Scalar input) } } - private sealed class OutVectorColumn : Vector, IColInput + private sealed class OutVectorColumn : Vector, IColInput { public PipelineColumn Input { get; } @@ -667,7 +675,7 @@ public OutVectorColumn(Vector input) } } - private sealed class OutVarVectorColumn : VarVector, IColInput + private sealed class OutVarVectorColumn : VarVector, IColInput { public PipelineColumn Input { get; } @@ -700,55 +708,55 @@ public override IEstimator Reconcile(IHostEnvironment env, } } - public static Scalar IsMissingValue(this Scalar input) + public static Scalar IsMissingValue(this Scalar input) { Contracts.CheckValue(input, nameof(input)); return new OutScalar(input); } - public static Scalar IsMissingValue(this Scalar input) + public static Scalar IsMissingValue(this Scalar input) { Contracts.CheckValue(input, nameof(input)); return new OutScalar(input); } - public static Scalar IsMissingValue(this Scalar input) + public static Scalar IsMissingValue(this Scalar input) { Contracts.CheckValue(input, nameof(input)); return new OutScalar(input); } - public static Vector IsMissingValue(this Vector input) + public static Vector IsMissingValue(this Vector input) { Contracts.CheckValue(input, nameof(input)); return new OutVectorColumn(input); } - public static Vector IsMissingValue(this Vector input) + public static Vector IsMissingValue(this Vector input) { Contracts.CheckValue(input, nameof(input)); return new OutVectorColumn(input); } - public static Vector IsMissingValue(this Vector input) + public static Vector IsMissingValue(this Vector input) { Contracts.CheckValue(input, nameof(input)); return new OutVectorColumn(input); } - public static VarVector IsMissingValue(this VarVector input) + public static VarVector IsMissingValue(this VarVector input) { Contracts.CheckValue(input, nameof(input)); return new OutVarVectorColumn(input); } - public static VarVector IsMissingValue(this VarVector input) + public static VarVector IsMissingValue(this VarVector input) { Contracts.CheckValue(input, nameof(input)); return new OutVarVectorColumn(input); } - public static VarVector IsMissingValue(this VarVector input) + public static VarVector IsMissingValue(this VarVector input) { Contracts.CheckValue(input, nameof(input)); return new OutVarVectorColumn(input); diff --git a/test/BaselineOutput/SingleDebug/NAIndicator/featurized.tsv b/test/BaselineOutput/SingleDebug/NAIndicator/featurized.tsv new file mode 100644 index 0000000000..f4dcf44107 --- /dev/null +++ b/test/BaselineOutput/SingleDebug/NAIndicator/featurized.tsv @@ -0,0 +1,17 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=ScalarFloat:R4:0 +#@ col=ScalarDouble:R8:1 +#@ col=VectorFloat:R4:2-5 +#@ col=VectorDoulbe:R8:6-9 +#@ col=A:BL:10 +#@ col=B:BL:11 +#@ col=C:BL:12-15 +#@ col=D:BL:16-19 +#@ } +ScalarFloat ScalarDouble 18 8:A 9:B +5 5 5 1 1 1 5 1 1 1 10 0:0 +5 5 5 4 4 5 5 4 4 5 10 0:0 +3 3 3 1 1 1 3 1 1 1 10 0:0 +6 6 6 8 8 1 6 8 8 1 10 0:0 diff --git a/test/BaselineOutput/SingleRelease/NAIndicator/featurized.tsv b/test/BaselineOutput/SingleRelease/NAIndicator/featurized.tsv new file mode 100644 index 0000000000..f4dcf44107 --- /dev/null +++ b/test/BaselineOutput/SingleRelease/NAIndicator/featurized.tsv @@ -0,0 +1,17 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=ScalarFloat:R4:0 +#@ col=ScalarDouble:R8:1 +#@ col=VectorFloat:R4:2-5 +#@ col=VectorDoulbe:R8:6-9 +#@ col=A:BL:10 +#@ col=B:BL:11 +#@ col=C:BL:12-15 +#@ col=D:BL:16-19 +#@ } +ScalarFloat ScalarDouble 18 8:A 9:B +5 5 5 1 1 1 5 1 1 1 10 0:0 +5 5 5 4 4 5 5 4 4 5 10 0:0 +3 3 3 1 1 1 3 1 1 1 10 0:0 +6 6 6 8 8 1 6 8 8 1 10 0:0 diff --git a/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs b/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs index 3fbebdae88..ed66733027 100644 --- a/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs @@ -2,13 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Data; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.Runtime.Tools; +using System; using System.IO; +using System.Linq; using Xunit; using Xunit.Abstractions; @@ -47,20 +50,43 @@ public void NAIndicatorWorkout() new NAIndicatorTransform.ColumnInfo("B", "NAC"), new NAIndicatorTransform.ColumnInfo("C", "NAD"), new NAIndicatorTransform.ColumnInfo("D", "NAE")); - // write a simple test with one NAindicatortransform and try to inspect the columns using pete's code. might be easier! - //TestEstimatorCore(pipe, dataView); + TestEstimatorCore(pipe, dataView); Done(); } + //[Fact] + //public void NAIndicatorSimpleWorkout() + //{ + // var data = new[] { + // new TestClass() { A = 1, B = 3, C = new float[2]{ 1, 2 } , D = new double[2]{ 3,4} }, + // new TestClass() { A = float.NaN, B = double.NaN, C = new float[2]{ float.NaN, float.NaN } , D = new double[2]{ double.NaN,double.NaN}}, + // new TestClass() { A = float.NegativeInfinity, B = double.NegativeInfinity, C = new float[2]{ float.NegativeInfinity, float.NegativeInfinity } , D = new double[2]{ double.NegativeInfinity, double.NegativeInfinity}}, + // new TestClass() { A = float.PositiveInfinity, B = double.PositiveInfinity, C = new float[2]{ float.PositiveInfinity, float.PositiveInfinity, } , D = new double[2]{ double.PositiveInfinity, double.PositiveInfinity}}, + // new TestClass() { A = 2, B = 1, C = new float[2]{ 3, 4 } , D = new double[2]{ 5,6}}, + // }; + + // var dataView = ComponentCreation.CreateDataView(Env, data); + // var pipe = new NAIndicatorEstimator(Env, + // new NAIndicatorTransform.ColumnInfo("A", "AA")); + // var transform = pipe.Fit(dataView); + // var outDataView = transform.Transform(dataView); + // var col = outDataView.GetColumn(Env, "AA").ToArray(); + // // write a simple test with one NAindicatortransform and try to inspect the columns using pete's code. might be easier! + // //TestEstimatorCore(pipe, dataView); + // foreach(bool c in col) + // { + // Console.WriteLine(c.ToString()); + // } + // Done(); + //} + [Fact] public void NAIndicatorStatic() { string dataPath = GetDataPath("breast-cancer.txt"); var reader = TextLoader.CreateReader(Env, ctx => ( - ScalarString: ctx.LoadText(1), ScalarFloat: ctx.LoadFloat(1), ScalarDouble: ctx.LoadDouble(1), - VectorString: ctx.LoadText(1, 4), VectorFloat: ctx.LoadFloat(1, 4), VectorDoulbe: ctx.LoadDouble(1, 4) )); @@ -71,11 +97,10 @@ public void NAIndicatorStatic() var est = data.MakeNewEstimator(). Append(row => ( - A: row.ScalarString.IsMissingValue(), + A: row.ScalarFloat.IsMissingValue(), B: row.ScalarDouble.IsMissingValue(), - C: row.VectorString.IsMissingValue(), - D: row.VectorFloat.IsMissingValue(), - F: row.VectorDoulbe.IsMissingValue() + C: row.VectorFloat.IsMissingValue(), + D: row.VectorDoulbe.IsMissingValue() )); TestEstimatorCore(est.AsDynamic, data.AsDynamic, invalidInput: invalidData); @@ -84,7 +109,6 @@ public void NAIndicatorStatic() { var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); IDataView savedData = TakeFilter.Create(Env, est.Fit(data).Transform(data).AsDynamic, 4); - savedData = new ChooseColumnsTransform(Env, savedData, "A", "B", "C", "D"); using (var fs = File.Create(outputPath)) DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); } From 4062bedc98e3f340353e7d2c3e8c93c1227af0c0 Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Wed, 10 Oct 2018 14:47:07 -0700 Subject: [PATCH 05/13] cleanup --- .../NAIndicatorTransform.cs | 1 - .../Transformers/NAIndicatorTests.cs | 41 +++---------------- 2 files changed, 6 insertions(+), 36 deletions(-) diff --git a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs index eb88b497f4..528b6b238d 100644 --- a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs +++ b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs @@ -14,7 +14,6 @@ using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.Model.Onnx; using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; diff --git a/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs b/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs index ed66733027..98ab1ab93a 100644 --- a/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs @@ -2,16 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Data; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.Runtime.Tools; -using System; using System.IO; -using System.Linq; using Xunit; using Xunit.Abstractions; @@ -47,39 +44,13 @@ public void NAIndicatorWorkout() var dataView = ComponentCreation.CreateDataView(Env, data); var pipe = new NAIndicatorEstimator(Env, new NAIndicatorTransform.ColumnInfo("A", "NAA"), - new NAIndicatorTransform.ColumnInfo("B", "NAC"), - new NAIndicatorTransform.ColumnInfo("C", "NAD"), - new NAIndicatorTransform.ColumnInfo("D", "NAE")); + new NAIndicatorTransform.ColumnInfo("B", "NAB"), + new NAIndicatorTransform.ColumnInfo("C", "NAC"), + new NAIndicatorTransform.ColumnInfo("D", "NAD")); TestEstimatorCore(pipe, dataView); Done(); } - //[Fact] - //public void NAIndicatorSimpleWorkout() - //{ - // var data = new[] { - // new TestClass() { A = 1, B = 3, C = new float[2]{ 1, 2 } , D = new double[2]{ 3,4} }, - // new TestClass() { A = float.NaN, B = double.NaN, C = new float[2]{ float.NaN, float.NaN } , D = new double[2]{ double.NaN,double.NaN}}, - // new TestClass() { A = float.NegativeInfinity, B = double.NegativeInfinity, C = new float[2]{ float.NegativeInfinity, float.NegativeInfinity } , D = new double[2]{ double.NegativeInfinity, double.NegativeInfinity}}, - // new TestClass() { A = float.PositiveInfinity, B = double.PositiveInfinity, C = new float[2]{ float.PositiveInfinity, float.PositiveInfinity, } , D = new double[2]{ double.PositiveInfinity, double.PositiveInfinity}}, - // new TestClass() { A = 2, B = 1, C = new float[2]{ 3, 4 } , D = new double[2]{ 5,6}}, - // }; - - // var dataView = ComponentCreation.CreateDataView(Env, data); - // var pipe = new NAIndicatorEstimator(Env, - // new NAIndicatorTransform.ColumnInfo("A", "AA")); - // var transform = pipe.Fit(dataView); - // var outDataView = transform.Transform(dataView); - // var col = outDataView.GetColumn(Env, "AA").ToArray(); - // // write a simple test with one NAindicatortransform and try to inspect the columns using pete's code. might be easier! - // //TestEstimatorCore(pipe, dataView); - // foreach(bool c in col) - // { - // Console.WriteLine(c.ToString()); - // } - // Done(); - //} - [Fact] public void NAIndicatorStatic() { @@ -137,9 +108,9 @@ public void TestOldSavingAndLoading() var dataView = ComponentCreation.CreateDataView(Env, data); var pipe = new NAIndicatorEstimator(Env, new NAIndicatorTransform.ColumnInfo("A", "NAA"), - new NAIndicatorTransform.ColumnInfo("B", "NAC"), - new NAIndicatorTransform.ColumnInfo("C", "NAD"), - new NAIndicatorTransform.ColumnInfo("D", "NAE")); + new NAIndicatorTransform.ColumnInfo("B", "NAB"), + new NAIndicatorTransform.ColumnInfo("C", "NAC"), + new NAIndicatorTransform.ColumnInfo("D", "NAD")); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); From 76a893aaa8b08e0ad74af3f77326705f44fb7e78 Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Wed, 10 Oct 2018 18:27:06 -0700 Subject: [PATCH 06/13] fixing an issue --- .../NAIndicatorTransform.cs | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs index 528b6b238d..fac198ca85 100644 --- a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs +++ b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs @@ -719,12 +719,6 @@ public static Scalar IsMissingValue(this Scalar input) return new OutScalar(input); } - public static Scalar IsMissingValue(this Scalar input) - { - Contracts.CheckValue(input, nameof(input)); - return new OutScalar(input); - } - public static Vector IsMissingValue(this Vector input) { Contracts.CheckValue(input, nameof(input)); @@ -737,12 +731,6 @@ public static Vector IsMissingValue(this Vector input) return new OutVectorColumn(input); } - public static Vector IsMissingValue(this Vector input) - { - Contracts.CheckValue(input, nameof(input)); - return new OutVectorColumn(input); - } - public static VarVector IsMissingValue(this VarVector input) { Contracts.CheckValue(input, nameof(input)); @@ -754,11 +742,5 @@ public static VarVector IsMissingValue(this VarVector input) Contracts.CheckValue(input, nameof(input)); return new OutVarVectorColumn(input); } - - public static VarVector IsMissingValue(this VarVector input) - { - Contracts.CheckValue(input, nameof(input)); - return new OutVarVectorColumn(input); - } } } \ No newline at end of file From 8f331fd507f441bc4b71425ae6227c04f2723160 Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Mon, 15 Oct 2018 17:51:25 -0700 Subject: [PATCH 07/13] fixed review comments --- .../NAHandleTransform.cs | 1 + src/Microsoft.ML.Transforms/NAHandling.cs | 3 +- .../NAIndicatorTransform.cs | 719 ++++++++---------- .../StaticPipeTests.cs | 47 ++ .../Transformers/NAIndicatorTests.cs | 49 +- 5 files changed, 351 insertions(+), 468 deletions(-) diff --git a/src/Microsoft.ML.Transforms/NAHandleTransform.cs b/src/Microsoft.ML.Transforms/NAHandleTransform.cs index 8645d72647..3c777e03a0 100644 --- a/src/Microsoft.ML.Transforms/NAHandleTransform.cs +++ b/src/Microsoft.ML.Transforms/NAHandleTransform.cs @@ -11,6 +11,7 @@ using Microsoft.ML.Runtime.Data.Conversion; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Transforms; [assembly: LoadableClass(NAHandleTransform.Summary, typeof(IDataTransform), typeof(NAHandleTransform), typeof(NAHandleTransform.Arguments), typeof(SignatureDataTransform), NAHandleTransform.FriendlyName, "NAHandleTransform", NAHandleTransform.ShortName, "NA", DocName = "transform/NAHandle.md")] diff --git a/src/Microsoft.ML.Transforms/NAHandling.cs b/src/Microsoft.ML.Transforms/NAHandling.cs index 8a6635ef77..79d3b49d5a 100644 --- a/src/Microsoft.ML.Transforms/NAHandling.cs +++ b/src/Microsoft.ML.Transforms/NAHandling.cs @@ -4,6 +4,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Transforms; [assembly: EntryPointModule(typeof(NAHandling))] @@ -71,7 +72,7 @@ public static CommonOutputs.TransformOutput Handle(IHostEnvironment env, NAHandl public static CommonOutputs.TransformOutput Indicator(IHostEnvironment env, NAIndicatorTransform.Arguments input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, "NAIndicator", input); - var xf = NAIndicatorTransform.Create(h, input, input.Data); + var xf = new NAIndicatorTransform(h, input).Transform(input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModel(h, xf, input.Data), diff --git a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs index fac198ca85..6d2dcd9f9d 100644 --- a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs +++ b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs @@ -16,20 +16,21 @@ using Microsoft.ML.Runtime.Model; using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; +using Microsoft.ML.Transforms; [assembly: LoadableClass(NAIndicatorTransform.Summary, typeof(IDataTransform), typeof(NAIndicatorTransform), typeof(NAIndicatorTransform.Arguments), typeof(SignatureDataTransform), - NAIndicatorTransform.FriendlyName, NAIndicatorTransform.LoadName, "NAIndicator", NAIndicatorTransform.ShortName, DocName = "transform/NAHandle.md")] + NAIndicatorTransform.FriendlyName, nameof(NAIndicatorTransform), "NAIndicator", NAIndicatorTransform.ShortName, DocName = "transform/NAHandle.md")] [assembly: LoadableClass(NAIndicatorTransform.Summary, typeof(IDataTransform), typeof(NAIndicatorTransform), null, typeof(SignatureLoadDataTransform), - NAIndicatorTransform.FriendlyName, NAIndicatorTransform.LoadName)] + NAIndicatorTransform.FriendlyName, nameof(NAIndicatorTransform))] [assembly: LoadableClass(NAIndicatorTransform.Summary, typeof(NAIndicatorTransform), null, typeof(SignatureLoadModel), - NAIndicatorTransform.FriendlyName, NAIndicatorTransform.LoadName)] + NAIndicatorTransform.FriendlyName, nameof(NAIndicatorTransform))] [assembly: LoadableClass(typeof(IRowMapper), typeof(NAIndicatorTransform), null, typeof(SignatureLoadRowMapper), - NAIndicatorTransform.FriendlyName, NAIndicatorTransform.LoadName)] + NAIndicatorTransform.FriendlyName, nameof(NAIndicatorTransform))] -namespace Microsoft.ML.Runtime.Data +namespace Microsoft.ML.Transforms { /// public sealed class NAIndicatorTransform : OneToOneTransformerBase @@ -59,17 +60,14 @@ public sealed class Arguments : TransformInputBase public Column[] Column; } - public const string LoadName = "NaIndicatorTransform"; - private static VersionInfo GetVersionInfo() { return new VersionInfo( - // REVIEW: temporary name modelSignature: "NAIND TF", verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: LoadName, + loaderSignature: nameof(NAIndicatorTransform), loaderAssemblyName: typeof(NAIndicatorTransform).Assembly.FullName); } @@ -78,170 +76,66 @@ private static VersionInfo GetVersionInfo() internal const string FriendlyName = "NA Indicator Transform"; internal const string ShortName = "NAInd"; - internal static string TestType(ColumnType type) - { - // Item type must have an NA value. We'll get the predicate again later when we're ready to use it. - Delegate del; - if (Conversions.Instance.TryGetIsNAPredicate(type.ItemType, out del)) - return null; - return string.Format("Type '{0}' is not supported by {1} since it doesn't have an NA value", - type, LoadName); - } - - public class ColumnInfo - { - public readonly string Input; - public readonly string Output; - - /// - /// Describes how the transformer handles one column pair. - /// - /// Name of input column. - /// Name of output column. - public ColumnInfo(string input, string output) - { - Input = input; - Output = output; - } - } - private const string RegistrationName = nameof(NAIndicatorTransform); - // The input column types - private ColumnType[] _inputTypes; - // The output column types - private ColumnType[] _outputTypes; - - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) - { - Contracts.CheckValue(columns, nameof(columns)); - return columns.Select(x => (x.Input, x.Output)).ToArray(); - } - - protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) - { - var type = inputSchema.GetColumnType(srcCol); - string reason = TestType(type); - if (reason != null) - throw Host.ExceptParam(nameof(inputSchema), reason); - } - - private (ColumnType[], ColumnType[]) GetTypes(ISchema schema) - { - var inputTypes = new ColumnType[ColumnPairs.Length]; - var outputTypes = new ColumnType[ColumnPairs.Length]; - for (int i = 0; i < ColumnPairs.Length; i++) - { - schema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc); - var type = schema.GetColumnType(colSrc); - - if (!type.IsVector) - { - inputTypes[i] = type.AsPrimitive; - outputTypes[i] = BoolType.Instance; - } - else - { - inputTypes[i] = new VectorType(type.ItemType.AsPrimitive, type.AsVector); - outputTypes[i] = new VectorType(BoolType.Instance, type.AsVector); - } - } - return (inputTypes, outputTypes); - } - /// - /// Convenience constructor for public facing API. + /// Initializes a new instance of /// - /// Host Environment. - /// Input . This is the output from previous transform or loader. - /// Specifies the names of the input columns for the transform and the resulting output column names. - public NAIndicatorTransform(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NAIndicatorTransform)), GetColumnPairs(columns)) + /// The environment to use. + /// The names of the input columns of the transformation and the corresponding names for the output columns. + internal NAIndicatorTransform(IHostEnvironment env, params (string input, string output)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NAIndicatorTransform)), columns) { - // Check that all the input columns are present and correct. - for (int i = 0; i < ColumnPairs.Length; i++) - { - if (!input.Schema.TryGetColumnIndex(ColumnPairs[i].input, out int srcCol)) - throw Host.ExceptSchemaMismatch(nameof(input), "input", ColumnPairs[i].input); - CheckInputColumn(input.Schema, i, srcCol); - } - (_inputTypes, _outputTypes) = GetTypes(input.Schema); } - private NAIndicatorTransform(IHost host, ModelLoadContext ctx) - : base(host, ctx) + internal NAIndicatorTransform(IHostEnvironment env, Arguments args) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NAIndicatorTransform)), GetColumnPairs(args.Column)) { - Host.AssertValue(ctx); - _outputTypes = null; - _inputTypes = null; } - /// - /// Factory method for SignatureDataTransform. - /// - public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + private NAIndicatorTransform(IHostEnvironment env, ModelLoadContext ctx) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NAIndicatorTransform)), ctx) { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); - env.CheckValue(input, nameof(input)); + Host.CheckValue(ctx, nameof(ctx)); + } - env.CheckValue(args.Column, nameof(args.Column)); - var cols = new ColumnInfo[args.Column.Length]; + private static (string input, string output)[] GetColumnPairs(Column[] columns) + { + var cols = new (string input, string output)[columns.Length]; for (int i = 0; i < cols.Length; i++) { - var item = args.Column[i]; + var item = columns[i]; - cols[i] = new ColumnInfo(item.Source, item.Name); + cols[i].input = item.Source; + cols[i].output = item.Name; }; - return new NAIndicatorTransform(env, input, cols).MakeDataTransform(input); + return cols; } - /// - /// Factory method for SignatureDataTransform. - /// - public static NAIndicatorTransform Create(IHostEnvironment env, ModelLoadContext ctx) + // Factory method for SignatureLoadModel + internal static NAIndicatorTransform Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); - var host = env.Register(RegistrationName); - - host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new NAIndicatorTransform(host, ctx); + return new NAIndicatorTransform(env, ctx); } - /// - /// Factory method for SignatureDataTransform. - /// - public static IDataTransform Create(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) - => new NAIndicatorTransform(env, input, columns).MakeDataTransform(input); + // Factory method for SignatureDataTransform. + internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + => new NAIndicatorTransform(env, args).MakeDataTransform(input); - /// - /// Factory method for SignatureDataTransform. - /// - public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + // Factory method for SignatureLoadDataTransform. + internal static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) => Create(env, ctx).MakeDataTransform(input); - /// - /// Factory method for SignatureDataTransform. - /// - public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + // Factory method for SignatureLoadRowMapper. + internal static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); /// - /// Returns the isNA predicate for the respective type. + /// Saves the transform. /// - private Delegate GetIsNADelegate(ColumnType type) - { - Func func = GetIsNADelegate; - return Utils.MarshalInvoke(func, type.ItemType.RawType, type); - } - - private Delegate GetIsNADelegate(ColumnType type) - { - return Conversions.Instance.GetIsNAPredicate(type.ItemType); - } - public override void Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); @@ -252,230 +146,6 @@ public override void Save(ModelSaveContext ctx) SaveColumns(ctx); } - private ColumnType GetColumnTypeCore(int iinfo) - { - Host.Assert(0 <= iinfo & iinfo < ColumnPairs.Length); - return _outputTypes[iinfo]; - } - - private Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) - { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < ColumnPairs.Length); - disposer = null; - if (_outputTypes == null || _inputTypes == null) - (_inputTypes, _outputTypes) = GetTypes(input.Schema); - - if (!_outputTypes[iinfo].IsVector) - return ComposeGetterOne(input, iinfo); - return ComposeGetterVec(input, iinfo); - } - - /// - /// Getter generator for single valued inputs. - /// - private ValueGetter ComposeGetterOne(IRow input, int iinfo) - { - Func> func = ComposeGetterOne; - return Utils.MarshalInvoke(func, _outputTypes[iinfo].RawType, input, iinfo); - } - - private ValueGetter ComposeGetterOne(IRow input, int iinfo) - { - Host.Assert(0 <= iinfo & iinfo < ColumnPairs.Length); - Host.Assert(input.IsColumnActive(iinfo)); - - var getSrc = input.GetGetter(iinfo); - var isNA = Conversions.Instance.GetIsNAPredicate(_inputTypes[iinfo]); - T src = default(T); - return - (ref bool dst) => - { - getSrc(ref src); - dst = isNA(ref src); - }; - } - - /// - /// Getter generator for vector valued inputs. - /// - private ValueGetter> ComposeGetterVec(IRow input, int iinfo) - { - Func>> func = ComposeGetterVec; - return Utils.MarshalInvoke(func, _outputTypes[iinfo].ItemType.RawType, input, iinfo); - } - - private ValueGetter> ComposeGetterVec(IRow input, int iinfo) - { - Host.Assert(0 <= iinfo & iinfo < ColumnPairs.Length); - Host.Assert(input.IsColumnActive(iinfo)); - - var getSrc = input.GetGetter>(iinfo); - var isNA = Conversions.Instance.GetIsNAPredicate(_inputTypes[iinfo]); - var val = default(T); - bool defaultIsNA = isNA(ref val); - var src = default(VBuffer); - var indices = new List(); - return - (ref VBuffer dst) => - { - // Sense indicates if the values added to the indices list represent NAs or non-NAs. - bool sense; - getSrc(ref src); - FindNAs(ref src, isNA, defaultIsNA, indices, out sense); - FillValues(src.Length, ref dst, indices, sense); - }; - } - - /// - /// Adds all NAs (or non-NAs) to the indices List. Whether NAs or non-NAs have been added is indicated by the bool sense. - /// - private void FindNAs(ref VBuffer src, RefPredicate isNA, bool defaultIsNA, List indices, out bool sense) - { - Host.AssertValue(isNA); - Host.AssertValue(indices); - - // Find the indices of all of the NAs. - indices.Clear(); - var srcValues = src.Values; - var srcCount = src.Count; - if (src.IsDense) - { - for (int i = 0; i < srcCount; i++) - { - if (isNA(ref srcValues[i])) - indices.Add(i); - } - sense = true; - } - else if (!defaultIsNA) - { - var srcIndices = src.Indices; - for (int ii = 0; ii < srcCount; ii++) - { - if (isNA(ref srcValues[ii])) - indices.Add(srcIndices[ii]); - } - sense = true; - } - else - { - // Note that this adds non-NAs to indices -- this is indicated by sense being false. - var srcIndices = src.Indices; - for (int ii = 0; ii < srcCount; ii++) - { - if (!isNA(ref srcValues[ii])) - indices.Add(srcIndices[ii]); - } - sense = false; - } - } - - /// - /// Fills indicator values for vectors. The indices is a list that either holds all of the NAs or all - /// of the non-NAs, indicated by sense being true or false respectively. - /// - private void FillValues(int srcLength, ref VBuffer dst, List indices, bool sense) - { - var dstValues = dst.Values; - var dstIndices = dst.Indices; - - if (indices.Count == 0) - { - if (sense) - { - // Return empty VBuffer. - dst = new VBuffer(srcLength, 0, dstValues, dstIndices); - return; - } - - // Return VBuffer filled with 1's. - Utils.EnsureSize(ref dstValues, srcLength, false); - for (int i = 0; i < srcLength; i++) - dstValues[i] = true; - dst = new VBuffer(srcLength, dstValues, dstIndices); - return; - } - - if (sense && indices.Count < srcLength / 2) - { - // Will produce sparse output. - int dstCount = indices.Count; - Utils.EnsureSize(ref dstValues, dstCount, false); - Utils.EnsureSize(ref dstIndices, dstCount, false); - - indices.CopyTo(dstIndices); - for (int ii = 0; ii < dstCount; ii++) - dstValues[ii] = true; - - Host.Assert(dstCount <= srcLength); - dst = new VBuffer(srcLength, dstCount, dstValues, dstIndices); - } - else if (!sense && srcLength - indices.Count < srcLength / 2) - { - // Will produce sparse output. - int dstCount = srcLength - indices.Count; - Utils.EnsureSize(ref dstValues, dstCount, false); - Utils.EnsureSize(ref dstIndices, dstCount, false); - - // Appends the length of the src to make the loop simpler, - // as the length of src will never be reached in the loop. - indices.Add(srcLength); - - int iiDst = 0; - int iiSrc = 0; - int iNext = indices[iiSrc]; - for (int i = 0; i < srcLength; i++) - { - Host.Assert(0 <= i && i <= iNext); - Host.Assert(iiSrc + iiDst == i); - if (i < iNext) - { - Host.Assert(iiDst < dstCount); - dstValues[iiDst] = true; - dstIndices[iiDst++] = i; - } - else - { - Host.Assert(iiSrc + 1 < indices.Count); - Host.Assert(iNext < indices[iiSrc + 1]); - iNext = indices[++iiSrc]; - } - } - Host.Assert(srcLength == iiSrc + iiDst); - Host.Assert(iiDst == dstCount); - - dst = new VBuffer(srcLength, dstCount, dstValues, dstIndices); - } - else - { - // Will produce dense output. - Utils.EnsureSize(ref dstValues, srcLength, false); - - // Appends the length of the src to make the loop simpler, - // as the length of src will never be reached in the loop. - indices.Add(srcLength); - - int ii = 0; - for (int i = 0; i < srcLength; i++) - { - Host.Assert(0 <= i && i <= indices[ii]); - if (i == indices[ii]) - { - dstValues[i] = sense; - ii++; - Host.Assert(ii < indices.Count); - Host.Assert(indices[ii - 1] < indices[ii]); - } - else - dstValues[i] = !sense; - } - - dst = new VBuffer(srcLength, dstValues, dstIndices); - } - } - protected override IRowMapper MakeRowMapper(ISchema schema) => new Mapper(this, schema); @@ -483,40 +153,31 @@ private sealed class Mapper : MapperBase { private sealed class ColInfo { - public readonly string Name; - public readonly string Source; - public readonly ColumnType TypeSrc; + public readonly string Output; + public readonly string Input; + public readonly ColumnType OutputType; + public readonly ColumnType InputType; + public readonly Delegate InputIsNA; - public ColInfo(string name, string source, ColumnType type) + public ColInfo(string input, string output, ColumnType inType, ColumnType outType) { - Name = name; - Source = source; - TypeSrc = type; + Input = input; + Output = output; + InputType = inType; + OutputType = outType; + InputIsNA = GetIsNADelegate(InputType); ; } } + // are we sure we need this? maybe not? private readonly NAIndicatorTransform _parent; private readonly ColInfo[] _infos; - private readonly ColumnType[] _types; - // The isNA delegates, parallel to Infos. - private readonly Delegate[] _isNAs; public Mapper(NAIndicatorTransform parent, ISchema inputSchema) : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { _parent = parent; _infos = CreateInfos(inputSchema); - _types = new ColumnType[_parent.ColumnPairs.Length]; - _isNAs = new Delegate[_parent.ColumnPairs.Length]; - for (int i = 0; i < _parent.ColumnPairs.Length; i++) - { - var type = _infos[i].TypeSrc; - if (!type.IsVector) - _types[i] = BoolType.Instance; - else - _types[i] = new VectorType(BoolType.Instance, type.AsVector); - _isNAs[i] = _parent.GetIsNADelegate(type); - } } private ColInfo[] CreateInfos(ISchema inputSchema) @@ -528,8 +189,13 @@ private ColInfo[] CreateInfos(ISchema inputSchema) if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colSrc)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); _parent.CheckInputColumn(inputSchema, i, colSrc); - var type = inputSchema.GetColumnType(colSrc); - infos[i] = new ColInfo(_parent.ColumnPairs[i].output, _parent.ColumnPairs[i].input, type); + var inType = inputSchema.GetColumnType(colSrc); + ColumnType outType; + if (!inType.IsVector) + outType = BoolType.Instance; + else + outType = new VectorType(BoolType.Instance, inType.AsVector); + infos[i] = new ColInfo(_parent.ColumnPairs[i].input, _parent.ColumnPairs[i].output, inType, outType); } return infos; } @@ -537,18 +203,38 @@ private ColInfo[] CreateInfos(ISchema inputSchema) public override RowMapperColumnInfo[] GetOutputColumns() { var result = new RowMapperColumnInfo[_parent.ColumnPairs.Length]; - for (int i = 0; i < _parent.ColumnPairs.Length; i++) - result[i] = new RowMapperColumnInfo(_parent.ColumnPairs[i].output, _types[i], default); + for (int iinfo = 0; iinfo < _infos.Length; iinfo++) + { + InputSchema.TryGetColumnIndex(_infos[iinfo].Input, out int colIndex); + Host.Assert(colIndex >= 0); + var colMetaInfo = new ColumnMetadataInfo(_infos[iinfo].Output); + var meta = RowColumnUtils.GetMetadataAsRow(InputSchema, colIndex, x => x == MetadataUtils.Kinds.SlotNames || x == MetadataUtils.Kinds.IsNormalized); + result[iinfo] = new RowMapperColumnInfo(_infos[iinfo].Output, _infos[iinfo].OutputType, meta); + } return result; } + /// + /// Returns the isNA predicate for the respective type. + /// + private static Delegate GetIsNADelegate(ColumnType type) + { + Func func = GetIsNADelegate; + return Utils.MarshalInvoke(func, type.ItemType.RawType, type); + } + + private static Delegate GetIsNADelegate(ColumnType type) + { + return Conversions.Instance.GetIsNAPredicate(type.ItemType); + } + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) { Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _infos.Length); disposer = null; - if (!_infos[iinfo].TypeSrc.IsVector) + if (!_infos[iinfo].InputType.IsVector) return ComposeGetterOne(input, iinfo); return ComposeGetterVec(input, iinfo); } @@ -557,13 +243,13 @@ protected override Delegate MakeGetter(IRow input, int iinfo, out Action dispose /// Getter generator for single valued inputs. /// private ValueGetter ComposeGetterOne(IRow input, int iinfo) - => Utils.MarshalInvoke(ComposeGetterOne, _infos[iinfo].TypeSrc.RawType, input, iinfo); + => Utils.MarshalInvoke(ComposeGetterOne, _infos[iinfo].InputType.RawType, input, iinfo); private ValueGetter ComposeGetterOne(IRow input, int iinfo) { var getSrc = input.GetGetter(ColMapNewToOld[iinfo]); var src = default(T); - var isNA = (RefPredicate)_isNAs[iinfo]; + var isNA = (RefPredicate)_infos[iinfo].InputIsNA; ValueGetter getter; @@ -579,12 +265,12 @@ private ValueGetter ComposeGetterOne(IRow input, int iinfo) /// Getter generator for vector valued inputs. /// private ValueGetter> ComposeGetterVec(IRow input, int iinfo) - => Utils.MarshalInvoke(ComposeGetterVec, _infos[iinfo].TypeSrc.ItemType.RawType, input, iinfo); + => Utils.MarshalInvoke(ComposeGetterVec, _infos[iinfo].InputType.ItemType.RawType, input, iinfo); private ValueGetter> ComposeGetterVec(IRow input, int iinfo) { var getSrc = input.GetGetter>(ColMapNewToOld[iinfo]); - var isNA = (RefPredicate)_isNAs[iinfo]; + var isNA = (RefPredicate)_infos[iinfo].InputIsNA; var val = default(T); var defaultIsNA = isNA(ref val); var src = default(VBuffer); @@ -598,48 +284,209 @@ private ValueGetter> ComposeGetterVec(IRow input, int iinfo) // Sense indicates if the values added to the indices list represent NAs or non-NAs. bool sense; getSrc(ref src); - _parent.FindNAs(ref src, isNA, defaultIsNA, indices, out sense); - _parent.FillValues(src.Length, ref dst, indices, sense); + FindNAs(ref src, isNA, defaultIsNA, indices, out sense); + FillValues(src.Length, ref dst, indices, sense); }; } + + /// + /// Adds all NAs (or non-NAs) to the indices List. Whether NAs or non-NAs have been added is indicated by the bool sense. + /// + private void FindNAs(ref VBuffer src, RefPredicate isNA, bool defaultIsNA, List indices, out bool sense) + { + Host.AssertValue(isNA); + Host.AssertValue(indices); + + // Find the indices of all of the NAs. + indices.Clear(); + var srcValues = src.Values; + var srcCount = src.Count; + if (src.IsDense) + { + for (int i = 0; i < srcCount; i++) + { + if (isNA(ref srcValues[i])) + indices.Add(i); + } + sense = true; + } + else if (!defaultIsNA) + { + var srcIndices = src.Indices; + for (int ii = 0; ii < srcCount; ii++) + { + if (isNA(ref srcValues[ii])) + indices.Add(srcIndices[ii]); + } + sense = true; + } + else + { + // Note that this adds non-NAs to indices -- this is indicated by sense being false. + var srcIndices = src.Indices; + for (int ii = 0; ii < srcCount; ii++) + { + if (!isNA(ref srcValues[ii])) + indices.Add(srcIndices[ii]); + } + sense = false; + } + } + + /// + /// Fills indicator values for vectors. The indices is a list that either holds all of the NAs or all + /// of the non-NAs, indicated by sense being true or false respectively. + /// + private void FillValues(int srcLength, ref VBuffer dst, List indices, bool sense) + { + var dstValues = dst.Values; + var dstIndices = dst.Indices; + + if (indices.Count == 0) + { + if (sense) + { + // Return empty VBuffer. + dst = new VBuffer(srcLength, 0, dstValues, dstIndices); + return; + } + + // Return VBuffer filled with 1's. + Utils.EnsureSize(ref dstValues, srcLength, false); + for (int i = 0; i < srcLength; i++) + dstValues[i] = true; + dst = new VBuffer(srcLength, dstValues, dstIndices); + return; + } + + if (sense && indices.Count < srcLength / 2) + { + // Will produce sparse output. + int dstCount = indices.Count; + Utils.EnsureSize(ref dstValues, dstCount, false); + Utils.EnsureSize(ref dstIndices, dstCount, false); + + indices.CopyTo(dstIndices); + for (int ii = 0; ii < dstCount; ii++) + dstValues[ii] = true; + + Host.Assert(dstCount <= srcLength); + dst = new VBuffer(srcLength, dstCount, dstValues, dstIndices); + } + else if (!sense && srcLength - indices.Count < srcLength / 2) + { + // Will produce sparse output. + int dstCount = srcLength - indices.Count; + Utils.EnsureSize(ref dstValues, dstCount, false); + Utils.EnsureSize(ref dstIndices, dstCount, false); + + // Appends the length of the src to make the loop simpler, + // as the length of src will never be reached in the loop. + indices.Add(srcLength); + + int iiDst = 0; + int iiSrc = 0; + int iNext = indices[iiSrc]; + for (int i = 0; i < srcLength; i++) + { + Host.Assert(0 <= i && i <= iNext); + Host.Assert(iiSrc + iiDst == i); + if (i < iNext) + { + Host.Assert(iiDst < dstCount); + dstValues[iiDst] = true; + dstIndices[iiDst++] = i; + } + else + { + Host.Assert(iiSrc + 1 < indices.Count); + Host.Assert(iNext < indices[iiSrc + 1]); + iNext = indices[++iiSrc]; + } + } + Host.Assert(srcLength == iiSrc + iiDst); + Host.Assert(iiDst == dstCount); + + dst = new VBuffer(srcLength, dstCount, dstValues, dstIndices); + } + else + { + // Will produce dense output. + Utils.EnsureSize(ref dstValues, srcLength, false); + + // Appends the length of the src to make the loop simpler, + // as the length of src will never be reached in the loop. + indices.Add(srcLength); + + int ii = 0; + for (int i = 0; i < srcLength; i++) + { + Host.Assert(0 <= i && i <= indices[ii]); + if (i == indices[ii]) + { + dstValues[i] = sense; + ii++; + Host.Assert(ii < indices.Count); + Host.Assert(indices[ii - 1] < indices[ii]); + } + else + dstValues[i] = !sense; + } + + dst = new VBuffer(srcLength, dstValues, dstIndices); + } + } } } - public sealed class NAIndicatorEstimator : IEstimator + public sealed class NAIndicatorEstimator : TrivialEstimator { - private readonly IHost _host; - private readonly NAIndicatorTransform.ColumnInfo[] _columns; + private readonly (string input, string output)[] _columnPairs; - public NAIndicatorEstimator(IHostEnvironment env, string name, string source = null) - : this(env, new NAIndicatorTransform.ColumnInfo(source ?? name, name)) + /// + /// Initializes a new instance of + /// + /// The environment to use. + /// The names of the input columns of the transformation and the corresponding names for the output columns. + public NAIndicatorEstimator(IHostEnvironment env, params (string input, string output)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NAIndicatorTransform)), new NAIndicatorTransform(env, columns)) { + Contracts.CheckValue(env, nameof(env)); + _columnPairs = columns; } - public NAIndicatorEstimator(IHostEnvironment env, params NAIndicatorTransform.ColumnInfo[] columns) + /// + /// Initializes a new instance of + /// + /// The environment to use. + /// The name of the input column of the transformation. + /// The name of the column produced by the transformation. + public NAIndicatorEstimator(IHostEnvironment env, string input, string output = null) + : this(env, (input, output ?? input)) { - Contracts.CheckValue(env, nameof(env)); - _host = env.Register(nameof(NAIndicatorEstimator)); - _columns = columns; } - public SchemaShape GetOutputSchema(SchemaShape inputSchema) + /// + /// Returns the schema that would be produced by the transformation. + /// + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { - _host.CheckValue(inputSchema, nameof(inputSchema)); + Host.CheckValue(inputSchema, nameof(inputSchema)); var result = inputSchema.Columns.ToDictionary(x => x.Name); - foreach (var colInfo in _columns) + foreach (var colPair in _columnPairs) { - if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); - string reason = NAIndicatorTransform.TestType(col.ItemType); - if (reason != null) - throw _host.ExceptParam(nameof(inputSchema), reason); + if (!inputSchema.TryFindColumn(colPair.input, out var col) || !Conversions.Instance.TryGetIsNAPredicate(col.ItemType, out Delegate del)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.input); + var metadata = new List(); + if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta)) + metadata.Add(slotMeta); + if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.IsNormalized, out var normalized)) + metadata.Add(normalized); ColumnType type = !col.ItemType.IsVector ? (ColumnType) BoolType.Instance : new VectorType(BoolType.Instance, col.ItemType.AsVector); - result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, type, false, default); + result[colPair.output] = new SchemaShape.Column(colPair.output, col.Kind, type, false, new SchemaShape(metadata.ToArray())); } return new SchemaShape(result.Values); } - - public NAIndicatorTransform Fit(IDataView input) => new NAIndicatorTransform(_host, input, _columns); } /// @@ -697,46 +544,76 @@ public override IEstimator Reconcile(IHostEnvironment env, IReadOnlyDictionary outputNames, IReadOnlyCollection usedNames) { - var infos = new NAIndicatorTransform.ColumnInfo[toOutput.Length]; + var columnPairs = new (string input, string output)[toOutput.Length]; for (int i = 0; i < toOutput.Length; ++i) { var col = (IColInput)toOutput[i]; - infos[i] = new NAIndicatorTransform.ColumnInfo(inputNames[col.Input], outputNames[toOutput[i]]); + columnPairs[i] = (inputNames[col.Input], outputNames[toOutput[i]]); } - return new NAIndicatorEstimator(env, infos); + return new NAIndicatorEstimator(env, columnPairs); } } + /// + /// Produces a column of boolean entries indicating wheter input column entries were missing. + /// + /// The input column. + /// A column indicating wheter input column entries were missing. public static Scalar IsMissingValue(this Scalar input) { Contracts.CheckValue(input, nameof(input)); return new OutScalar(input); } + /// + /// Produces a column of boolean entries indicating wheter input column entries were missing. + /// + /// The input column. + /// A column indicating wheter input column entries were missing. public static Scalar IsMissingValue(this Scalar input) { Contracts.CheckValue(input, nameof(input)); return new OutScalar(input); } + /// + /// Produces a column of boolean entries indicating wheter input column entries were missing. + /// + /// The input column. + /// A column indicating wheter input column entries were missing. public static Vector IsMissingValue(this Vector input) { Contracts.CheckValue(input, nameof(input)); return new OutVectorColumn(input); } + /// + /// Produces a column of boolean entries indicating wheter input column entries were missing. + /// + /// The input column. + /// A column indicating wheter input column entries were missing. public static Vector IsMissingValue(this Vector input) { Contracts.CheckValue(input, nameof(input)); return new OutVectorColumn(input); } + /// + /// Produces a column of boolean entries indicating wheter input column entries were missing. + /// + /// The input column. + /// A column indicating wheter input column entries were missing. public static VarVector IsMissingValue(this VarVector input) { Contracts.CheckValue(input, nameof(input)); return new OutVarVectorColumn(input); } + /// + /// Produces a column of boolean entries indicating wheter input column entries were missing. + /// + /// The input column. + /// A column indicating wheter input column entries were missing. public static VarVector IsMissingValue(this VarVector input) { Contracts.CheckValue(input, nameof(input)); diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index 5a5c2c76c7..0384b9da7a 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Data; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Internal.Utilities; @@ -771,5 +772,51 @@ public void PrincipalComponentAnalysis() var type = schema.GetColumnType(pcaCol); Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber); } + + + [Fact] + public void NAIndicatorStatic() + { + var Env = new ConsoleEnvironment(seed: 0); + + string dataPath = GetDataPath("breast-cancer.txt"); + var reader = TextLoader.CreateReader(Env, ctx => ( + ScalarFloat: ctx.LoadFloat(1), + ScalarDouble: ctx.LoadDouble(1), + VectorFloat: ctx.LoadFloat(1, 4), + VectorDoulbe: ctx.LoadDouble(1, 4) + )); + + var data = reader.Read(new MultiFileSource(dataPath)); + + var est = data.MakeNewEstimator(). + Append(row => ( + A: row.ScalarFloat.IsMissingValue(), + B: row.ScalarDouble.IsMissingValue(), + C: row.VectorFloat.IsMissingValue(), + D: row.VectorDoulbe.IsMissingValue() + )); + + IDataView newData = TakeFilter.Create(Env, est.Fit(data).Transform(data).AsDynamic, 4); + Assert.NotNull(newData); + bool[] ScalarFloat = newData.GetColumn(Env, "A").ToArray(); + bool[] ScalarDouble = newData.GetColumn(Env, "B").ToArray(); + bool[][] VectorFloat = newData.GetColumn(Env, "C").ToArray(); + bool[][] VectorDoulbe = newData.GetColumn(Env, "D").ToArray(); + + Assert.NotNull(ScalarFloat); + Assert.NotNull(ScalarDouble); + Assert.NotNull(VectorFloat); + Assert.NotNull(VectorDoulbe); + for (int i = 0; i < 4; i++) + { + Assert.True(!ScalarFloat[i] && !ScalarDouble[i]); + Assert.NotNull(VectorFloat[i]); + Assert.NotNull(VectorDoulbe[i]); + for (int j = 0; j < 4; j++) + Assert.True(!VectorFloat[i][j] && !VectorDoulbe[i][j]); + } + + } } } \ No newline at end of file diff --git a/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs b/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs index 98ab1ab93a..c838d8986d 100644 --- a/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs @@ -8,6 +8,7 @@ using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.Runtime.Tools; +using Microsoft.ML.Transforms; using System.IO; using Xunit; using Xunit.Abstractions; @@ -43,51 +44,11 @@ public void NAIndicatorWorkout() var dataView = ComponentCreation.CreateDataView(Env, data); var pipe = new NAIndicatorEstimator(Env, - new NAIndicatorTransform.ColumnInfo("A", "NAA"), - new NAIndicatorTransform.ColumnInfo("B", "NAB"), - new NAIndicatorTransform.ColumnInfo("C", "NAC"), - new NAIndicatorTransform.ColumnInfo("D", "NAD")); + new (string input, string output)[] { ("A", "NAA"), ("B", "NAB"), ("C", "NAC"), ("D", "NAD") }); TestEstimatorCore(pipe, dataView); Done(); } - [Fact] - public void NAIndicatorStatic() - { - string dataPath = GetDataPath("breast-cancer.txt"); - var reader = TextLoader.CreateReader(Env, ctx => ( - ScalarFloat: ctx.LoadFloat(1), - ScalarDouble: ctx.LoadDouble(1), - VectorFloat: ctx.LoadFloat(1, 4), - VectorDoulbe: ctx.LoadDouble(1, 4) - )); - - var data = reader.Read(new MultiFileSource(dataPath)); - var wrongCollection = new[] { new TestClass() { A = 1, B = 3, C = new float[2] { 1, 2 }, D = new double[2] { 3, 4 } } }; - var invalidData = ComponentCreation.CreateDataView(Env, wrongCollection); - - var est = data.MakeNewEstimator(). - Append(row => ( - A: row.ScalarFloat.IsMissingValue(), - B: row.ScalarDouble.IsMissingValue(), - C: row.VectorFloat.IsMissingValue(), - D: row.VectorDoulbe.IsMissingValue() - )); - - TestEstimatorCore(est.AsDynamic, data.AsDynamic, invalidInput: invalidData); - var outputPath = GetOutputPath("NAIndicator", "featurized.tsv"); - using (var ch = Env.Start("save")) - { - var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); - IDataView savedData = TakeFilter.Create(Env, est.Fit(data).Transform(data).AsDynamic, 4); - using (var fs = File.Create(outputPath)) - DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); - } - - CheckEquality("NAIndicator", "featurized.tsv"); - Done(); - } - [Fact] public void TestCommandLine() { @@ -107,11 +68,7 @@ public void TestOldSavingAndLoading() var dataView = ComponentCreation.CreateDataView(Env, data); var pipe = new NAIndicatorEstimator(Env, - new NAIndicatorTransform.ColumnInfo("A", "NAA"), - new NAIndicatorTransform.ColumnInfo("B", "NAB"), - new NAIndicatorTransform.ColumnInfo("C", "NAC"), - new NAIndicatorTransform.ColumnInfo("D", "NAD")); - + new (string input, string output)[] { ("A", "NAA"), ("B", "NAB"), ("C", "NAC"), ("D", "NAD") }); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); using (var ms = new MemoryStream()) From fe94d3055d537a051418a158d93aaebfacc1c50a Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Mon, 15 Oct 2018 18:19:57 -0700 Subject: [PATCH 08/13] fixed some review comments and added a test --- .../NAIndicatorTransform.cs | 31 +++++++++--------- .../Transformers/NAIndicatorTests.cs | 32 +++++++++++++++++++ 2 files changed, 47 insertions(+), 16 deletions(-) diff --git a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs index 6d2dcd9f9d..a8aa7181b2 100644 --- a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs +++ b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs @@ -151,6 +151,9 @@ protected override IRowMapper MakeRowMapper(ISchema schema) private sealed class Mapper : MapperBase { + private readonly NAIndicatorTransform _parent; + private readonly ColInfo[] _infos; + private sealed class ColInfo { public readonly string Output; @@ -169,10 +172,6 @@ public ColInfo(string input, string output, ColumnType inType, ColumnType outTyp } } - // are we sure we need this? maybe not? - private readonly NAIndicatorTransform _parent; - private readonly ColInfo[] _infos; - public Mapper(NAIndicatorTransform parent, ISchema inputSchema) : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { @@ -555,10 +554,10 @@ public override IEstimator Reconcile(IHostEnvironment env, } /// - /// Produces a column of boolean entries indicating wheter input column entries were missing. + /// Produces a column of boolean entries indicating whether input column entries were missing. /// /// The input column. - /// A column indicating wheter input column entries were missing. + /// A column indicating whether input column entries were missing. public static Scalar IsMissingValue(this Scalar input) { Contracts.CheckValue(input, nameof(input)); @@ -566,10 +565,10 @@ public static Scalar IsMissingValue(this Scalar input) } /// - /// Produces a column of boolean entries indicating wheter input column entries were missing. + /// Produces a column of boolean entries indicating whether input column entries were missing. /// /// The input column. - /// A column indicating wheter input column entries were missing. + /// A column indicating whether input column entries were missing. public static Scalar IsMissingValue(this Scalar input) { Contracts.CheckValue(input, nameof(input)); @@ -577,10 +576,10 @@ public static Scalar IsMissingValue(this Scalar input) } /// - /// Produces a column of boolean entries indicating wheter input column entries were missing. + /// Produces a column of boolean entries indicating whether input column entries were missing. /// /// The input column. - /// A column indicating wheter input column entries were missing. + /// A column indicating whether input column entries were missing. public static Vector IsMissingValue(this Vector input) { Contracts.CheckValue(input, nameof(input)); @@ -588,10 +587,10 @@ public static Vector IsMissingValue(this Vector input) } /// - /// Produces a column of boolean entries indicating wheter input column entries were missing. + /// Produces a column of boolean entries indicating whether input column entries were missing. /// /// The input column. - /// A column indicating wheter input column entries were missing. + /// A column indicating whether input column entries were missing. public static Vector IsMissingValue(this Vector input) { Contracts.CheckValue(input, nameof(input)); @@ -599,10 +598,10 @@ public static Vector IsMissingValue(this Vector input) } /// - /// Produces a column of boolean entries indicating wheter input column entries were missing. + /// Produces a column of boolean entries indicating whether input column entries were missing. /// /// The input column. - /// A column indicating wheter input column entries were missing. + /// A column indicating whether input column entries were missing. public static VarVector IsMissingValue(this VarVector input) { Contracts.CheckValue(input, nameof(input)); @@ -610,10 +609,10 @@ public static VarVector IsMissingValue(this VarVector input) } /// - /// Produces a column of boolean entries indicating wheter input column entries were missing. + /// Produces a column of boolean entries indicating whether input column entries were missing. /// /// The input column. - /// A column indicating wheter input column entries were missing. + /// A column indicating whether input column entries were missing. public static VarVector IsMissingValue(this VarVector input) { Contracts.CheckValue(input, nameof(input)); diff --git a/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs b/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs index c838d8986d..b0db716a09 100644 --- a/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs @@ -78,5 +78,37 @@ public void TestOldSavingAndLoading() var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms); } } + + [Fact] + public void NAIndicatorFileOutput() + { + string dataPath = GetDataPath("breast-cancer.txt"); + var reader = TextLoader.CreateReader(Env, ctx => ( + ScalarFloat: ctx.LoadFloat(1), + ScalarDouble: ctx.LoadDouble(1), + VectorFloat: ctx.LoadFloat(1, 4), + VectorDoulbe: ctx.LoadDouble(1, 4) + )); + + var data = reader.Read(new MultiFileSource(dataPath)).AsDynamic; + var wrongCollection = new[] { new TestClass() { A = 1, B = 3, C = new float[2] { 1, 2 }, D = new double[2] { 3, 4 } } }; + var invalidData = ComponentCreation.CreateDataView(Env, wrongCollection); + var est = new NAIndicatorEstimator(Env, + new (string input, string output)[] { ("ScalarFloat", "A"), ("ScalarDouble", "B"), ("VectorFloat", "C"), ("VectorDoulbe", "D") }); + + TestEstimatorCore(est, data, invalidInput: invalidData); + var outputPath = GetOutputPath("NAIndicator", "featurized.tsv"); + using (var ch = Env.Start("save")) + { + var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); + IDataView savedData = TakeFilter.Create(Env, est.Fit(data).Transform(data), 4); + using (var fs = File.Create(outputPath)) + DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); + } + + CheckEquality("NAIndicator", "featurized.tsv"); + Done(); + } + } } From 122d48cdef9a0ad5c948dcf7cf3a235664ab720e Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Tue, 16 Oct 2018 14:16:59 -0700 Subject: [PATCH 09/13] fixed entrypointcatalog test --- test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index c4434d23b7..22f929ecb4 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -112,7 +112,7 @@ Transforms.ManyHeterogeneousModelCombiner Combines a sequence of TransformModels Transforms.MeanVarianceNormalizer Normalizes the data based on the computed mean and variance of the data. Microsoft.ML.Runtime.Data.Normalize MeanVar Microsoft.ML.Runtime.Data.NormalizeTransform+MeanVarArguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.MinMaxNormalizer Normalizes the data based on the observed minimum and maximum values of the data. Microsoft.ML.Runtime.Data.Normalize MinMax Microsoft.ML.Runtime.Data.NormalizeTransform+MinMaxArguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.MissingValueHandler Handle missing values by replacing them with either the default value or the mean/min/max value (for non-text columns only). An indicator column can optionally be concatenated, if theinput column type is numeric. Microsoft.ML.Runtime.Data.NAHandling Handle Microsoft.ML.Runtime.Data.NAHandleTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput -Transforms.MissingValueIndicator Create a boolean output column with the same number of slots as the input column, where the output value is true if the value in the input column is missing. Microsoft.ML.Runtime.Data.NAHandling Indicator Microsoft.ML.Runtime.Data.NAIndicatorTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput +Transforms.MissingValueIndicator Create a boolean output column with the same number of slots as the input column, where the output value is true if the value in the input column is missing. Microsoft.ML.Runtime.Data.NAHandling Indicator Microsoft.ML.Transforms.NAIndicatorTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.MissingValuesDropper Removes NAs from vector columns. Microsoft.ML.Runtime.Data.NAHandling Drop Microsoft.ML.Runtime.Data.NADropTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.MissingValuesRowDropper Filters out rows that contain missing values. Microsoft.ML.Runtime.Data.NAHandling Filter Microsoft.ML.Runtime.Data.NAFilter+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.MissingValueSubstitutor Create an output column of the same type and size of the input column, where missing values are replaced with either the default value or the mean/min/max value (for non-text columns only). Microsoft.ML.Runtime.Data.NAHandling Replace Microsoft.ML.Runtime.Data.NAReplaceTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput From 28e9570de3f5c4ea2a99fe310acddc285ded08cd Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Tue, 16 Oct 2018 17:18:00 -0700 Subject: [PATCH 10/13] propagated metadata --- .../NAIndicatorTransform.cs | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs index a8aa7181b2..0ce7c2756c 100644 --- a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs +++ b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs @@ -139,15 +139,13 @@ internal static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, IS public override void Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); - SaveColumns(ctx); } protected override IRowMapper MakeRowMapper(ISchema schema) - => new Mapper(this, schema); + => new Mapper(this, Schema.Create(schema)); private sealed class Mapper : MapperBase { @@ -172,14 +170,14 @@ public ColInfo(string input, string output, ColumnType inType, ColumnType outTyp } } - public Mapper(NAIndicatorTransform parent, ISchema inputSchema) + public Mapper(NAIndicatorTransform parent, Schema inputSchema) : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { _parent = parent; _infos = CreateInfos(inputSchema); } - private ColInfo[] CreateInfos(ISchema inputSchema) + private ColInfo[] CreateInfos(Schema inputSchema) { Host.AssertValue(inputSchema); var infos = new ColInfo[_parent.ColumnPairs.Length]; @@ -199,16 +197,21 @@ private ColInfo[] CreateInfos(ISchema inputSchema) return infos; } - public override RowMapperColumnInfo[] GetOutputColumns() + public override Schema.Column[] GetOutputColumns() { - var result = new RowMapperColumnInfo[_parent.ColumnPairs.Length]; + var result = new Schema.Column[_parent.ColumnPairs.Length]; for (int iinfo = 0; iinfo < _infos.Length; iinfo++) { InputSchema.TryGetColumnIndex(_infos[iinfo].Input, out int colIndex); Host.Assert(colIndex >= 0); - var colMetaInfo = new ColumnMetadataInfo(_infos[iinfo].Output); - var meta = RowColumnUtils.GetMetadataAsRow(InputSchema, colIndex, x => x == MetadataUtils.Kinds.SlotNames || x == MetadataUtils.Kinds.IsNormalized); - result[iinfo] = new RowMapperColumnInfo(_infos[iinfo].Output, _infos[iinfo].OutputType, meta); + var builder = new Schema.Metadata.Builder(); + builder.Add(InputSchema[colIndex].Metadata, x => x == MetadataUtils.Kinds.SlotNames); + ValueGetter getter = (ref bool dst) => + { + dst = true; + }; + builder.Add(new Schema.Column(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, null), getter); + result[iinfo] = new Schema.Column(_infos[iinfo].Output, _infos[iinfo].OutputType, builder.GetMetadata()); } return result; } @@ -479,8 +482,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var metadata = new List(); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta)) metadata.Add(slotMeta); - if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.IsNormalized, out var normalized)) - metadata.Add(normalized); + metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)); ColumnType type = !col.ItemType.IsVector ? (ColumnType) BoolType.Instance : new VectorType(BoolType.Instance, col.ItemType.AsVector); result[colPair.output] = new SchemaShape.Column(colPair.output, col.Kind, type, false, new SchemaShape(metadata.ToArray())); } From f720a7562dfa12954e72856089bdc8fe6893e24a Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Wed, 17 Oct 2018 18:24:29 -0700 Subject: [PATCH 11/13] fixed review comments --- .../NAIndicatorTransform.cs | 35 +++++++----------- .../TrainerEstimators/LbfgsTests.cs | 4 +-- .../SymSgdClassificationTests.cs | 2 +- .../TrainerEstimators/TrainerEstimators.cs | 4 +-- .../Transformers/NAIndicatorTests.cs | 36 +++++++++++++++++++ 5 files changed, 54 insertions(+), 27 deletions(-) diff --git a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs index 0ce7c2756c..15045d4015 100644 --- a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs +++ b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs @@ -19,16 +19,16 @@ using Microsoft.ML.Transforms; [assembly: LoadableClass(NAIndicatorTransform.Summary, typeof(IDataTransform), typeof(NAIndicatorTransform), typeof(NAIndicatorTransform.Arguments), typeof(SignatureDataTransform), - NAIndicatorTransform.FriendlyName, nameof(NAIndicatorTransform), "NAIndicator", NAIndicatorTransform.ShortName, DocName = "transform/NAHandle.md")] + NAIndicatorTransform.FriendlyName, NAIndicatorTransform.LoadName, "NAIndicator", NAIndicatorTransform.ShortName, DocName = "transform/NAHandle.md")] [assembly: LoadableClass(NAIndicatorTransform.Summary, typeof(IDataTransform), typeof(NAIndicatorTransform), null, typeof(SignatureLoadDataTransform), - NAIndicatorTransform.FriendlyName, nameof(NAIndicatorTransform))] + NAIndicatorTransform.FriendlyName, NAIndicatorTransform.LoadName)] [assembly: LoadableClass(NAIndicatorTransform.Summary, typeof(NAIndicatorTransform), null, typeof(SignatureLoadModel), - NAIndicatorTransform.FriendlyName, nameof(NAIndicatorTransform))] + NAIndicatorTransform.FriendlyName, NAIndicatorTransform.LoadName)] [assembly: LoadableClass(typeof(IRowMapper), typeof(NAIndicatorTransform), null, typeof(SignatureLoadRowMapper), - NAIndicatorTransform.FriendlyName, nameof(NAIndicatorTransform))] + NAIndicatorTransform.FriendlyName, NAIndicatorTransform.LoadName)] namespace Microsoft.ML.Transforms { @@ -60,6 +60,8 @@ public sealed class Arguments : TransformInputBase public Column[] Column; } + internal const string LoadName = "NaIndicatorTransform"; + private static VersionInfo GetVersionInfo() { return new VersionInfo( @@ -67,7 +69,7 @@ private static VersionInfo GetVersionInfo() verWrittenCur: 0x00010001, // Initial verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, - loaderSignature: nameof(NAIndicatorTransform), + loaderSignature: LoadName, loaderAssemblyName: typeof(NAIndicatorTransform).Assembly.FullName); } @@ -78,12 +80,14 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = nameof(NAIndicatorTransform); + internal (string input, string output)[] GetColumnPairs() => ColumnPairs; + /// /// Initializes a new instance of /// /// The environment to use. /// The names of the input columns of the transformation and the corresponding names for the output columns. - internal NAIndicatorTransform(IHostEnvironment env, params (string input, string output)[] columns) + public NAIndicatorTransform(IHostEnvironment env, params (string input, string output)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NAIndicatorTransform)), columns) { } @@ -100,17 +104,7 @@ private NAIndicatorTransform(IHostEnvironment env, ModelLoadContext ctx) } private static (string input, string output)[] GetColumnPairs(Column[] columns) - { - var cols = new (string input, string output)[columns.Length]; - for (int i = 0; i < cols.Length; i++) - { - var item = columns[i]; - - cols[i].input = item.Source; - cols[i].output = item.Name; - }; - return cols; - } + => columns.Select(c => (c.Source ?? c.Name, c.Name)).ToArray(); // Factory method for SignatureLoadModel internal static NAIndicatorTransform Create(IHostEnvironment env, ModelLoadContext ctx) @@ -166,7 +160,7 @@ public ColInfo(string input, string output, ColumnType inType, ColumnType outTyp Output = output; InputType = inType; OutputType = outType; - InputIsNA = GetIsNADelegate(InputType); ; + InputIsNA = GetIsNADelegate(InputType); } } @@ -443,8 +437,6 @@ private void FillValues(int srcLength, ref VBuffer dst, List indices, public sealed class NAIndicatorEstimator : TrivialEstimator { - private readonly (string input, string output)[] _columnPairs; - /// /// Initializes a new instance of /// @@ -454,7 +446,6 @@ public NAIndicatorEstimator(IHostEnvironment env, params (string input, string o : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NAIndicatorTransform)), new NAIndicatorTransform(env, columns)) { Contracts.CheckValue(env, nameof(env)); - _columnPairs = columns; } /// @@ -475,7 +466,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); var result = inputSchema.Columns.ToDictionary(x => x.Name); - foreach (var colPair in _columnPairs) + foreach (var colPair in Transformer.GetColumnPairs()) { if (!inputSchema.TryFindColumn(colPair.input, out var col) || !Conversions.Instance.TryGetIsNAPredicate(col.ItemType, out Delegate del)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.input); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs index fb54ad52ce..ecbd4e4363 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/LbfgsTests.cs @@ -15,7 +15,7 @@ public partial class TrainerEstimators public void TestEstimatorLogisticRegression() { (IEstimator pipe, IDataView dataView) = GetBinaryClassificationPipeline(); - pipe.Append(new LogisticRegression(Env, "Features", "Label")); + pipe = pipe.Append(new LogisticRegression(Env, "Features", "Label")); TestEstimatorCore(pipe, dataView); Done(); } @@ -24,7 +24,7 @@ public void TestEstimatorLogisticRegression() public void TestEstimatorMulticlassLogisticRegression() { (IEstimator pipe, IDataView dataView) = GetMultiClassPipeline(); - pipe.Append(new MulticlassLogisticRegression(Env, "Features", "Label")); + pipe = pipe.Append(new MulticlassLogisticRegression(Env, "Features", "Label")); TestEstimatorCore(pipe, dataView); Done(); } diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs index 68e126c439..a8655b9aac 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/SymSgdClassificationTests.cs @@ -22,7 +22,7 @@ public partial class TrainerEstimators public void TestEstimatorSymSgdClassificationTrainer() { (var pipe, var dataView) = GetBinaryClassificationPipeline(); - pipe.Append(new SymSgdClassificationTrainer(Env, "Features", "Label")); + pipe = pipe.Append(new SymSgdClassificationTrainer(Env, "Features", "Label")); TestEstimatorCore(pipe, dataView); Done(); } diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs index 538fc80a88..15e0de7c0f 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs @@ -84,7 +84,7 @@ public void KMeansEstimator() public void TestEstimatorHogwildSGD() { (IEstimator pipe, IDataView dataView) = GetBinaryClassificationPipeline(); - pipe.Append(new StochasticGradientDescentClassificationTrainer(Env, "Features", "Label")); + pipe = pipe.Append(new StochasticGradientDescentClassificationTrainer(Env, "Features", "Label")); TestEstimatorCore(pipe, dataView); Done(); } @@ -96,7 +96,7 @@ public void TestEstimatorHogwildSGD() public void TestEstimatorMultiClassNaiveBayesTrainer() { (IEstimator pipe, IDataView dataView) = GetMultiClassPipeline(); - pipe.Append(new MultiClassNaiveBayesTrainer(Env, "Features", "Label")); + pipe = pipe.Append(new MultiClassNaiveBayesTrainer(Env, "Features", "Label")); TestEstimatorCore(pipe, dataView); Done(); } diff --git a/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs b/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs index b0db716a09..7d04318ebb 100644 --- a/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs @@ -9,6 +9,7 @@ using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.Runtime.Tools; using Microsoft.ML.Transforms; +using System; using System.IO; using Xunit; using Xunit.Abstractions; @@ -110,5 +111,40 @@ public void NAIndicatorFileOutput() Done(); } + [Fact] + public void NAIndicatorMetadataTest() + { + var data = new[] { + new TestClass() { A = 1, B = 3, C = new float[2]{ 1, 2 } , D = new double[2]{ 3,4} }, + new TestClass() { A = float.NaN, B = double.NaN, C = new float[2]{ float.NaN, float.NaN } , D = new double[2]{ double.NaN,double.NaN}}, + new TestClass() { A = float.NegativeInfinity, B = double.NegativeInfinity, C = new float[2]{ float.NegativeInfinity, float.NegativeInfinity } , D = new double[2]{ double.NegativeInfinity, double.NegativeInfinity}}, + new TestClass() { A = float.PositiveInfinity, B = double.PositiveInfinity, C = new float[2]{ float.PositiveInfinity, float.PositiveInfinity, } , D = new double[2]{ double.PositiveInfinity, double.PositiveInfinity}}, + new TestClass() { A = 2, B = 1, C = new float[2]{ 3, 4 } , D = new double[2]{ 5,6}}, + }; + + var dataView = ComponentCreation.CreateDataView(Env, data); + var args = new NAIndicatorTransform.Arguments(); + args.Column = new NAIndicatorTransform.Column[] { new NAIndicatorTransform.Column() { Name = "NAA", Source = "A" } }; + var pipe = new CategoricalEstimator(Env, new CategoricalEstimator.ColumnInfo("A", "CatA")); + var newpipe = pipe.Append(new NAIndicatorEstimator(Env, new (string input, string output)[] { ("CatA", "NAA") })); + var result = newpipe.Fit(dataView).Transform(dataView); + Assert.True(result.Schema.TryGetColumnIndex("NAA", out var col)); + // Check that the column is normalized. + Assert.True(result.Schema.IsNormalized(col)); + // Check that slot names metadata was correctly created. + var value = new VBuffer>(); + var type = result.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, col); + result.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, col, ref value); + Assert.True(value.Length == 4); + var mem = new ReadOnlyMemory(); + value.GetItemOrDefault(0, ref mem); + Assert.True(mem.ToString() == "1"); + value.GetItemOrDefault(1, ref mem); + Assert.True(mem.ToString() == "-Infinity"); + value.GetItemOrDefault(2, ref mem); + Assert.True(mem.ToString() == "Infinity"); + value.GetItemOrDefault(3, ref mem); + Assert.True(mem.ToString() == "2"); + } } } From 5d4b6bc4587ff99c78e216bee5ee2a8711958200 Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Thu, 18 Oct 2018 13:59:31 -0700 Subject: [PATCH 12/13] fixed review comments --- src/Microsoft.ML.Transforms/NAIndicatorTransform.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs index 15045d4015..cb9cc6bfbb 100644 --- a/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs +++ b/src/Microsoft.ML.Transforms/NAIndicatorTransform.cs @@ -80,7 +80,7 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = nameof(NAIndicatorTransform); - internal (string input, string output)[] GetColumnPairs() => ColumnPairs; + public IReadOnlyList<(string input, string output)> Columns => ColumnPairs.AsReadOnly(); /// /// Initializes a new instance of @@ -466,7 +466,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); var result = inputSchema.Columns.ToDictionary(x => x.Name); - foreach (var colPair in Transformer.GetColumnPairs()) + foreach (var colPair in Transformer.Columns) { if (!inputSchema.TryFindColumn(colPair.input, out var col) || !Conversions.Instance.TryGetIsNAPredicate(col.ItemType, out Delegate del)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.input); From 0f4507f11e1e013588375d51d9ae0fa6b405a481 Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Fri, 19 Oct 2018 16:26:12 -0700 Subject: [PATCH 13/13] removed unused object --- test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs b/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs index 7d04318ebb..505f39ebf7 100644 --- a/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs @@ -123,8 +123,6 @@ public void NAIndicatorMetadataTest() }; var dataView = ComponentCreation.CreateDataView(Env, data); - var args = new NAIndicatorTransform.Arguments(); - args.Column = new NAIndicatorTransform.Column[] { new NAIndicatorTransform.Column() { Name = "NAA", Source = "A" } }; var pipe = new CategoricalEstimator(Env, new CategoricalEstimator.ColumnInfo("A", "CatA")); var newpipe = pipe.Append(new NAIndicatorEstimator(Env, new (string input, string output)[] { ("CatA", "NAA") })); var result = newpipe.Fit(dataView).Transform(dataView);