diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index babca545c8..fd94bc1739 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -40,6 +40,23 @@ public sealed partial class TextLoader : IDataLoader /// public sealed class Column { + public Column() { } + + public Column(string name, DataKind? type, int index) + : this(name, type, new[] { new Range(index) }) { } + + public Column(string name, DataKind? type, Range[] source, KeyRange keyRange = null) + { + Contracts.CheckValue(name, nameof(name)); + Contracts.CheckValue(source, nameof(source)); + Contracts.CheckValueOrNull(keyRange); + + Name = name; + Type = type; + Source = source; + KeyRange = keyRange; + } + [Argument(ArgumentType.AtMostOnce, HelpText = "Name of the column")] public string Name; @@ -179,6 +196,20 @@ public bool IsValid() public sealed class Range { + public Range() { } + + public Range(int index) + : this(index, index) { } + + public Range(int min, int max) + { + Contracts.CheckParam(min >= 0, nameof(min), "min must be non-negative."); + Contracts.CheckParam(max >= min, nameof(max), "max must be greater than or equal to min."); + + Min = min; + Max = max; + } + [Argument(ArgumentType.Required, HelpText = "First index in the range")] public int Min; diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs index e36f8545b0..48f3f9ddc3 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs @@ -491,13 +491,13 @@ private TextLoader.Column GetColumn(string name, ColumnType type, int? start) { var key = type.ItemType.AsKey; if (!key.Contiguous) - keyRange = new KeyRange() { Min = key.Min, Contiguous = false }; + keyRange = new KeyRange(key.Min, contiguous: false); else if (key.Count == 0) - keyRange = new KeyRange() { Min = key.Min }; + keyRange = new KeyRange(key.Min); else { Contracts.Assert(key.Count >= 1); - keyRange = new KeyRange() { Min = key.Min, Max = key.Min + (ulong)(key.Count - 1) }; + keyRange = new KeyRange(key.Min, key.Min + (ulong)(key.Count - 1)); } kind = key.RawKind; } diff --git a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs index 6365e27a54..bd82dc588a 100644 --- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs @@ -352,15 +352,7 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, Argu new TextLoader.Arguments() { Separator = "tab", - Column = new[] - { - new TextLoader.Column() - { - Name ="Term", - Type = DataKind.TX, - Source = new[] { new TextLoader.Range() { Min = 0 } } - } - } + Column = new[] { new TextLoader.Column("Term", DataKind.TX, 0) } }, fileSource); src = "Term"; diff --git a/src/Microsoft.ML.Data/Utilities/TypeParsingUtils.cs b/src/Microsoft.ML.Data/Utilities/TypeParsingUtils.cs index a7453f07de..a24d7d3883 100644 --- a/src/Microsoft.ML.Data/Utilities/TypeParsingUtils.cs +++ b/src/Microsoft.ML.Data/Utilities/TypeParsingUtils.cs @@ -85,6 +85,15 @@ public static KeyType ConstructKeyType(DataKind? type, KeyRange range) /// public sealed class KeyRange { + public KeyRange() { } + + public KeyRange(ulong min, ulong? max = null, bool contiguous = true) + { + Min = min; + Max = max; + Contiguous = contiguous; + } + [Argument(ArgumentType.AtMostOnce, HelpText = "First index in the range")] public ulong Min; diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index c44678bf07..403b8d07cd 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -43,19 +43,9 @@ private IDataView GetBreastCancerDataView() { Column = new[] { - new TextLoader.Column() - { - Name = "Label", - Source = new [] { new TextLoader.Range() { Min = 0, Max = 0} }, - Type = Runtime.Data.DataKind.R4 - }, - - new TextLoader.Column() - { - Name = "Features", - Source = new [] { new TextLoader.Range() { Min = 1, Max = 9} }, - Type = Runtime.Data.DataKind.R4 - } + new TextLoader.Column("Label", DataKind.R4, 0), + new TextLoader.Column("Features", DataKind.R4, + new [] { new TextLoader.Range(1, 9) }) } }, @@ -74,31 +64,10 @@ private IDataView GetBreastCancerDataviewWithTextColumns() HasHeader = true, Column = new[] { - new TextLoader.Column() - { - Name = "Label", - Source = new [] { new TextLoader.Range() { Min = 0, Max = 0} } - }, - - new TextLoader.Column() - { - Name = "F1", - Source = new [] { new TextLoader.Range() { Min = 1, Max = 1} }, - Type = Runtime.Data.DataKind.Text - }, - - new TextLoader.Column() - { - Name = "F2", - Source = new [] { new TextLoader.Range() { Min = 2, Max = 2} }, - Type = Runtime.Data.DataKind.I4 - }, - - new TextLoader.Column() - { - Name = "Rest", - Source = new [] { new TextLoader.Range() { Min = 3, Max = 9} } - } + new TextLoader.Column("Label", type: null, 0), + new TextLoader.Column("F1", DataKind.Text, 1), + new TextLoader.Column("F2", DataKind.I4, 2), + new TextLoader.Column("Rest", type: null, new [] { new TextLoader.Range(3, 9) }) } }, @@ -998,19 +967,8 @@ public void EntryPointPipelineEnsembleText() HasHeader = true, Column = new[] { - new TextLoader.Column() - { - Name = "Label", - Source = new [] { new TextLoader.Range() { Min = 0, Max = 0} }, - Type = Runtime.Data.DataKind.TX - }, - - new TextLoader.Column() - { - Name = "Text", - Source = new [] { new TextLoader.Range() { Min = 3, Max = 3} }, - Type = Runtime.Data.DataKind.TX - } + new TextLoader.Column("Label", DataKind.TX, 0), + new TextLoader.Column("Text", DataKind.TX, 3) } }, @@ -1222,19 +1180,8 @@ public void EntryPointMulticlassPipelineEnsemble() { Column = new[] { - new TextLoader.Column() - { - Name = "Label", - Source = new [] { new TextLoader.Range() { Min = 0, Max = 0} }, - Type = Runtime.Data.DataKind.R4 - }, - - new TextLoader.Column() - { - Name = "Features", - Source = new [] { new TextLoader.Range() { Min = 1, Max = 4} }, - Type = Runtime.Data.DataKind.R4 - } + new TextLoader.Column("Label", DataKind.R4, 0), + new TextLoader.Column("Features", DataKind.R4, new [] { new TextLoader.Range(1, 4) }) } }, @@ -3474,18 +3421,8 @@ public void EntryPointLinearPredictorSummary() HasHeader = true, Column = new[] { - new TextLoader.Column() - { - Name = "Label", - Source = new [] { new TextLoader.Range() { Min = 0, Max = 0} }, - }, - - new TextLoader.Column() - { - Name = "Features", - Source = new [] { new TextLoader.Range() { Min = 1, Max = 9} }, - Type = Runtime.Data.DataKind.Num - } + new TextLoader.Column("Label", type: null, 0), + new TextLoader.Column("Features", DataKind.Num, new [] { new TextLoader.Range(1, 9) }) } }, @@ -3561,12 +3498,7 @@ public void EntryPointPcaPredictorSummary() HasHeader = false, Column = new[] { - new TextLoader.Column() - { - Name = "Features", - Source = new [] { new TextLoader.Range() { Min = 1, Max = 784} }, - Type = Runtime.Data.DataKind.R4 - } + new TextLoader.Column("Features", DataKind.R4, new [] { new TextLoader.Range(1, 784) }) } }, @@ -3774,12 +3706,8 @@ public void EntryPointWordEmbeddings() SeparatorChars = new []{' '}, Column = new[] { - new TextLoader.Column() - { - Name = "Text", - Source = new [] { new TextLoader.Range() { Min = 0, VariableEnd=true, ForceVector=true} }, - Type = DataKind.Text - } + new TextLoader.Column("Text", DataKind.Text, + new [] { new TextLoader.Range() { Min = 0, VariableEnd=true, ForceVector=true} }) } }, InputFile = inputFile, diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs index 246281777c..b05135692a 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs @@ -82,44 +82,18 @@ private static TextTransform.Arguments MakeSentimentTextTransformArgs(bool norma private static TextLoader.Arguments MakeIrisTextLoaderArgs() { - return new TextLoader.Arguments() { Separator = "comma", HasHeader = true, Column = new[] - { - new TextLoader.Column() - { - Name = "SepalLength", - Source = new [] { new TextLoader.Range() { Min=0, Max=0} }, - Type = DataKind.R4 - }, - new TextLoader.Column() - { - Name = "SepalWidth", - Source = new [] { new TextLoader.Range() { Min=1, Max=1} }, - Type = DataKind.R4 - }, - new TextLoader.Column() - { - Name = "PetalLength", - Source = new [] { new TextLoader.Range() { Min=2, Max=2} }, - Type = DataKind.R4 - }, - new TextLoader.Column() - { - Name = "PetalWidth", - Source = new [] { new TextLoader.Range() { Min=3, Max=3} }, - Type = DataKind.R4 - }, - new TextLoader.Column() - { - Name = "Label", - Source = new [] { new TextLoader.Range() { Min=4, Max=4} }, - Type = DataKind.Text - } - } + { + new TextLoader.Column("SepalLength", DataKind.R4, 0), + new TextLoader.Column("SepalWidth", DataKind.R4, 1), + new TextLoader.Column("PetalLength", DataKind.R4, 2), + new TextLoader.Column("PetalWidth",DataKind.R4, 3), + new TextLoader.Column("Label", DataKind.Text, 4) + } }; } private static TextLoader.Arguments MakeSentimentTextLoaderArgs() @@ -129,21 +103,10 @@ private static TextLoader.Arguments MakeSentimentTextLoaderArgs() Separator = "tab", HasHeader = true, Column = new[] - { - new TextLoader.Column() - { - Name = "Label", - Source = new [] { new TextLoader.Range() { Min=0, Max=0} }, - Type = DataKind.BL - }, - - new TextLoader.Column() - { - Name = "SentimentText", - Source = new [] { new TextLoader.Range() { Min=1, Max=1} }, - Type = DataKind.Text - } - } + { + new TextLoader.Column("Label", DataKind.BL, 0), + new TextLoader.Column("SentimentText", DataKind.Text, 1) + } }; } } diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs index 0535f2b15d..b6333b3055 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs @@ -29,37 +29,13 @@ public void TrainAndPredictIrisModelUsingDirectInstantiationTest() new TextLoader.Arguments() { HasHeader = false, - Column = new[] { - new TextLoader.Column() - { - Name = "Label", - Source = new [] { new TextLoader.Range() { Min = 0, Max = 0} }, - Type = DataKind.R4 - }, - new TextLoader.Column() - { - Name = "SepalLength", - Source = new [] { new TextLoader.Range() { Min = 1, Max = 1} }, - Type = DataKind.R4 - }, - new TextLoader.Column() - { - Name = "SepalWidth", - Source = new [] { new TextLoader.Range() { Min = 2, Max = 2} }, - Type = DataKind.R4 - }, - new TextLoader.Column() - { - Name = "PetalLength", - Source = new [] { new TextLoader.Range() { Min = 3, Max = 3} }, - Type = DataKind.R4 - }, - new TextLoader.Column() - { - Name = "PetalWidth", - Source = new [] { new TextLoader.Range() { Min = 4, Max = 4} }, - Type = DataKind.R4 - } + Column = new[] + { + new TextLoader.Column("Label", DataKind.R4, 0), + new TextLoader.Column("SepalLength", DataKind.R4, 1), + new TextLoader.Column("SepalWidth", DataKind.R4, 2), + new TextLoader.Column("PetalLength", DataKind.R4, 3), + new TextLoader.Column("PetalWidth", DataKind.R4, 4) } }, new MultiFileSource(dataPath)); diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs index 7dd7a0145c..e42de17090 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs @@ -36,19 +36,8 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTest() HasHeader = true, Column = new[] { - new TextLoader.Column() - { - Name = "Label", - Source = new [] { new TextLoader.Range() { Min=0, Max=0} }, - Type = DataKind.Num - }, - - new TextLoader.Column() - { - Name = "SentimentText", - Source = new [] { new TextLoader.Range() { Min=1, Max=1} }, - Type = DataKind.Text - } + new TextLoader.Column("Label", DataKind.Num, 0), + new TextLoader.Column("SentimentText", DataKind.Text, 1) } }, new MultiFileSource(dataPath)); @@ -116,19 +105,8 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTestWithWordE HasHeader = true, Column = new[] { - new TextLoader.Column() - { - Name = "Label", - Source = new [] { new TextLoader.Range() { Min=0, Max=0} }, - Type = DataKind.Num - }, - - new TextLoader.Column() - { - Name = "SentimentText", - Source = new [] { new TextLoader.Range() { Min=1, Max=1} }, - Type = DataKind.Text - } + new TextLoader.Column("Label", DataKind.Num, 0), + new TextLoader.Column("SentimentText", DataKind.Text, 1) } }, new MultiFileSource(dataPath));