Skip to content

Add convenience constructors for TextLoader. #698

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 22, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
Original file line number Diff line number Diff line change
@@ -40,6 +40,23 @@ public sealed partial class TextLoader : IDataLoader
/// </example>
public sealed class Column
{
/// <summary>
/// Parameterless constructor. Assigns no fields itself.
/// </summary>
public Column()
{
}

/// <summary>
/// Creates a column that reads from a single source index.
/// Delegates to the full constructor with a one-element <see cref="Range"/> array
/// built from <paramref name="index"/>.
/// </summary>
/// <param name="name">Value stored in <see cref="Name"/>.</param>
/// <param name="type">Value stored in <c>Type</c>; may be null.</param>
/// <param name="index">Source index wrapped into a single <see cref="Range"/>.</param>
public Column(string name, DataKind? type, int index)
    : this(name, type, new Range[] { new Range(index) })
{
}

/// <summary>
/// Creates a column from an explicit set of source ranges and an optional key range.
/// </summary>
/// <param name="name">Value stored in <see cref="Name"/>; validated with <c>Contracts.CheckValue</c>.</param>
/// <param name="type">Value stored in <c>Type</c>; may be null.</param>
/// <param name="source">Value stored in <c>Source</c>; validated with <c>Contracts.CheckValue</c>.</param>
/// <param name="keyRange">Optional value stored in <c>KeyRange</c>; defaults to null.</param>
public Column(string name, DataKind? type, Range[] source, KeyRange keyRange = null)
{
    // Validation order is kept as-is so the first failing argument is reported first.
    Contracts.CheckValue(name, nameof(name));
    Contracts.CheckValue(source, nameof(source));
    Contracts.CheckValueOrNull(keyRange);

    KeyRange = keyRange;
    Source = source;
    Type = type;
    Name = name;
}

[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the column")]
public string Name;

@@ -179,6 +196,20 @@ public bool IsValid()

public sealed class Range
{
/// <summary>
/// Parameterless constructor. Assigns no fields itself.
/// </summary>
public Range()
{
}

public Range(int index)
Copy link
Contributor

@Zruty0 Zruty0 Aug 20, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Range [](start = 19, length = 5)

Does it make sense to add a ctor that takes min and max? #Closed

: this(index, index) { }

/// <summary>
/// Creates a range from <paramref name="min"/> to <paramref name="max"/>.
/// </summary>
/// <param name="min">Value stored in <see cref="Min"/>; must be non-negative.</param>
/// <param name="max">Value stored in <c>Max</c>; must not be less than <paramref name="min"/>.</param>
public Range(int min, int max)
{
    // Check min first so a negative min is reported before the min/max ordering.
    Contracts.CheckParam(0 <= min, nameof(min), "min must be non-negative.");
    Contracts.CheckParam(min <= max, nameof(max), "max must be greater than or equal to min.");

    Max = max;
    Min = min;
}

[Argument(ArgumentType.Required, HelpText = "First index in the range")]
public int Min;

6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs
Original file line number Diff line number Diff line change
@@ -491,13 +491,13 @@ private TextLoader.Column GetColumn(string name, ColumnType type, int? start)
{
var key = type.ItemType.AsKey;
if (!key.Contiguous)
keyRange = new KeyRange() { Min = key.Min, Contiguous = false };
keyRange = new KeyRange(key.Min, contiguous: false);
else if (key.Count == 0)
keyRange = new KeyRange() { Min = key.Min };
keyRange = new KeyRange(key.Min);
else
{
Contracts.Assert(key.Count >= 1);
keyRange = new KeyRange() { Min = key.Min, Max = key.Min + (ulong)(key.Count - 1) };
keyRange = new KeyRange(key.Min, key.Min + (ulong)(key.Count - 1));
}
kind = key.RawKind;
}
10 changes: 1 addition & 9 deletions src/Microsoft.ML.Data/Transforms/TermTransform.cs
Original file line number Diff line number Diff line change
@@ -352,15 +352,7 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, Argu
new TextLoader.Arguments()
{
Separator = "tab",
Column = new[]
{
new TextLoader.Column()
{
Name ="Term",
Type = DataKind.TX,
Source = new[] { new TextLoader.Range() { Min = 0 } }
}
}
Column = new[] { new TextLoader.Column("Term", DataKind.TX, 0) }
},
fileSource);
src = "Term";
9 changes: 9 additions & 0 deletions src/Microsoft.ML.Data/Utilities/TypeParsingUtils.cs
Original file line number Diff line number Diff line change
@@ -85,6 +85,15 @@ public static KeyType ConstructKeyType(DataKind? type, KeyRange range)
/// </summary>
public sealed class KeyRange
{
/// <summary>
/// Parameterless constructor. Assigns no fields itself.
/// </summary>
public KeyRange()
{
}

/// <summary>
/// Creates a key range from the given bounds and contiguity flag.
/// No validation is performed on the arguments here.
/// </summary>
/// <param name="min">Value stored in <see cref="Min"/>.</param>
/// <param name="max">Value stored in <c>Max</c>; defaults to null.</param>
/// <param name="contiguous">Value stored in <c>Contiguous</c>; defaults to true.</param>
public KeyRange(ulong min, ulong? max = null, bool contiguous = true)
{
    Contiguous = contiguous;
    Max = max;
    Min = min;
}

[Argument(ArgumentType.AtMostOnce, HelpText = "First index in the range")]
public ulong Min;

104 changes: 16 additions & 88 deletions test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
Original file line number Diff line number Diff line change
@@ -43,19 +43,9 @@ private IDataView GetBreastCancerDataView()
{
Column = new[]
{
new TextLoader.Column()
{
Name = "Label",
Source = new [] { new TextLoader.Range() { Min = 0, Max = 0} },
Type = Runtime.Data.DataKind.R4
},

new TextLoader.Column()
{
Name = "Features",
Source = new [] { new TextLoader.Range() { Min = 1, Max = 9} },
Type = Runtime.Data.DataKind.R4
}
new TextLoader.Column("Label", DataKind.R4, 0),
new TextLoader.Column("Features", DataKind.R4,
new [] { new TextLoader.Range(1, 9) })
}
},

@@ -74,31 +64,10 @@ private IDataView GetBreastCancerDataviewWithTextColumns()
HasHeader = true,
Column = new[]
{
new TextLoader.Column()
{
Name = "Label",
Source = new [] { new TextLoader.Range() { Min = 0, Max = 0} }
},

new TextLoader.Column()
{
Name = "F1",
Source = new [] { new TextLoader.Range() { Min = 1, Max = 1} },
Type = Runtime.Data.DataKind.Text
},

new TextLoader.Column()
{
Name = "F2",
Source = new [] { new TextLoader.Range() { Min = 2, Max = 2} },
Type = Runtime.Data.DataKind.I4
},

new TextLoader.Column()
{
Name = "Rest",
Source = new [] { new TextLoader.Range() { Min = 3, Max = 9} }
}
new TextLoader.Column("Label", type: null, 0),
new TextLoader.Column("F1", DataKind.Text, 1),
new TextLoader.Column("F2", DataKind.I4, 2),
new TextLoader.Column("Rest", type: null, new [] { new TextLoader.Range(3, 9) })
}
},

@@ -998,19 +967,8 @@ public void EntryPointPipelineEnsembleText()
HasHeader = true,
Column = new[]
{
new TextLoader.Column()
{
Name = "Label",
Source = new [] { new TextLoader.Range() { Min = 0, Max = 0} },
Type = Runtime.Data.DataKind.TX
},

new TextLoader.Column()
{
Name = "Text",
Source = new [] { new TextLoader.Range() { Min = 3, Max = 3} },
Type = Runtime.Data.DataKind.TX
}
new TextLoader.Column("Label", DataKind.TX, 0),
new TextLoader.Column("Text", DataKind.TX, 3)
}
},

@@ -1222,19 +1180,8 @@ public void EntryPointMulticlassPipelineEnsemble()
{
Column = new[]
{
new TextLoader.Column()
{
Name = "Label",
Source = new [] { new TextLoader.Range() { Min = 0, Max = 0} },
Type = Runtime.Data.DataKind.R4
},

new TextLoader.Column()
{
Name = "Features",
Source = new [] { new TextLoader.Range() { Min = 1, Max = 4} },
Type = Runtime.Data.DataKind.R4
}
new TextLoader.Column("Label", DataKind.R4, 0),
new TextLoader.Column("Features", DataKind.R4, new [] { new TextLoader.Range(1, 4) })
}
},

@@ -3474,18 +3421,8 @@ public void EntryPointLinearPredictorSummary()
HasHeader = true,
Column = new[]
{
new TextLoader.Column()
{
Name = "Label",
Source = new [] { new TextLoader.Range() { Min = 0, Max = 0} },
},

new TextLoader.Column()
{
Name = "Features",
Source = new [] { new TextLoader.Range() { Min = 1, Max = 9} },
Type = Runtime.Data.DataKind.Num
}
new TextLoader.Column("Label", type: null, 0),
new TextLoader.Column("Features", DataKind.Num, new [] { new TextLoader.Range(1, 9) })
}
},

@@ -3561,12 +3498,7 @@ public void EntryPointPcaPredictorSummary()
HasHeader = false,
Column = new[]
{
new TextLoader.Column()
{
Name = "Features",
Source = new [] { new TextLoader.Range() { Min = 1, Max = 784} },
Type = Runtime.Data.DataKind.R4
}
new TextLoader.Column("Features", DataKind.R4, new [] { new TextLoader.Range(1, 784) })
}
},

@@ -3774,12 +3706,8 @@ public void EntryPointWordEmbeddings()
SeparatorChars = new []{' '},
Column = new[]
{
new TextLoader.Column()
{
Name = "Text",
Source = new [] { new TextLoader.Range() { Min = 0, VariableEnd=true, ForceVector=true} },
Type = DataKind.Text
}
new TextLoader.Column("Text", DataKind.Text,
new [] { new TextLoader.Range() { Min = 0, VariableEnd=true, ForceVector=true} })
}
},
InputFile = inputFile,
59 changes: 11 additions & 48 deletions test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs
Original file line number Diff line number Diff line change
@@ -82,44 +82,18 @@ private static TextTransform.Arguments MakeSentimentTextTransformArgs(bool norma

private static TextLoader.Arguments MakeIrisTextLoaderArgs()
{

return new TextLoader.Arguments()
{
Separator = "comma",
HasHeader = true,
Column = new[]
{
new TextLoader.Column()
{
Name = "SepalLength",
Source = new [] { new TextLoader.Range() { Min=0, Max=0} },
Type = DataKind.R4
},
new TextLoader.Column()
{
Name = "SepalWidth",
Source = new [] { new TextLoader.Range() { Min=1, Max=1} },
Type = DataKind.R4
},
new TextLoader.Column()
{
Name = "PetalLength",
Source = new [] { new TextLoader.Range() { Min=2, Max=2} },
Type = DataKind.R4
},
new TextLoader.Column()
{
Name = "PetalWidth",
Source = new [] { new TextLoader.Range() { Min=3, Max=3} },
Type = DataKind.R4
},
new TextLoader.Column()
{
Name = "Label",
Source = new [] { new TextLoader.Range() { Min=4, Max=4} },
Type = DataKind.Text
}
}
{
new TextLoader.Column("SepalLength", DataKind.R4, 0),
new TextLoader.Column("SepalWidth", DataKind.R4, 1),
new TextLoader.Column("PetalLength", DataKind.R4, 2),
new TextLoader.Column("PetalWidth",DataKind.R4, 3),
new TextLoader.Column("Label", DataKind.Text, 4)
}
};
}
private static TextLoader.Arguments MakeSentimentTextLoaderArgs()
@@ -129,21 +103,10 @@ private static TextLoader.Arguments MakeSentimentTextLoaderArgs()
Separator = "tab",
HasHeader = true,
Column = new[]
{
new TextLoader.Column()
{
Name = "Label",
Source = new [] { new TextLoader.Range() { Min=0, Max=0} },
Type = DataKind.BL
},

new TextLoader.Column()
{
Name = "SentimentText",
Source = new [] { new TextLoader.Range() { Min=1, Max=1} },
Type = DataKind.Text
}
}
{
new TextLoader.Column("Label", DataKind.BL, 0),
new TextLoader.Column("SentimentText", DataKind.Text, 1)
}
};
}
}
Original file line number Diff line number Diff line change
@@ -29,37 +29,13 @@ public void TrainAndPredictIrisModelUsingDirectInstantiationTest()
new TextLoader.Arguments()
{
HasHeader = false,
Column = new[] {
new TextLoader.Column()
{
Name = "Label",
Source = new [] { new TextLoader.Range() { Min = 0, Max = 0} },
Type = DataKind.R4
},
new TextLoader.Column()
{
Name = "SepalLength",
Source = new [] { new TextLoader.Range() { Min = 1, Max = 1} },
Type = DataKind.R4
},
new TextLoader.Column()
{
Name = "SepalWidth",
Source = new [] { new TextLoader.Range() { Min = 2, Max = 2} },
Type = DataKind.R4
},
new TextLoader.Column()
{
Name = "PetalLength",
Source = new [] { new TextLoader.Range() { Min = 3, Max = 3} },
Type = DataKind.R4
},
new TextLoader.Column()
{
Name = "PetalWidth",
Source = new [] { new TextLoader.Range() { Min = 4, Max = 4} },
Type = DataKind.R4
}
Column = new[]
{
new TextLoader.Column("Label", DataKind.R4, 0),
new TextLoader.Column("SepalLength", DataKind.R4, 1),
new TextLoader.Column("SepalWidth", DataKind.R4, 2),
new TextLoader.Column("PetalLength", DataKind.R4, 3),
new TextLoader.Column("PetalWidth", DataKind.R4, 4)
}
}, new MultiFileSource(dataPath));

Original file line number Diff line number Diff line change
@@ -36,19 +36,8 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTest()
HasHeader = true,
Column = new[]
{
new TextLoader.Column()
{
Name = "Label",
Source = new [] { new TextLoader.Range() { Min=0, Max=0} },
Type = DataKind.Num
},

new TextLoader.Column()
{
Name = "SentimentText",
Source = new [] { new TextLoader.Range() { Min=1, Max=1} },
Type = DataKind.Text
}
new TextLoader.Column("Label", DataKind.Num, 0),
new TextLoader.Column("SentimentText", DataKind.Text, 1)
}
}, new MultiFileSource(dataPath));

@@ -116,19 +105,8 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTestWithWordE
HasHeader = true,
Column = new[]
{
new TextLoader.Column()
{
Name = "Label",
Source = new [] { new TextLoader.Range() { Min=0, Max=0} },
Type = DataKind.Num
},

new TextLoader.Column()
{
Name = "SentimentText",
Source = new [] { new TextLoader.Range() { Min=1, Max=1} },
Type = DataKind.Text
}
new TextLoader.Column("Label", DataKind.Num, 0),
new TextLoader.Column("SentimentText", DataKind.Text, 1)
}
}, new MultiFileSource(dataPath));