Skip to content

Commit 1e71b57

Browse files
authored
Added tests for text featurizer options (Part2). (#3036)
1 parent 5b22420 commit 1e71b57

File tree

1 file changed

+96
-6
lines changed

1 file changed

+96
-6
lines changed

test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs

+96-6
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ private class TestClass
3131
{
3232
public string A;
3333
public string[] OutputTokens;
34+
public float[] Features = null;
3435
}
3536

3637
[Fact]
@@ -41,7 +42,7 @@ public void TextFeaturizerWithPredefinedStopWordRemoverTest()
4142
var dataView = ML.Data.LoadFromEnumerable(data);
4243

4344
var options = new TextFeaturizingEstimator.Options() { StopWordsRemoverOptions = new StopWordsRemovingEstimator.Options(), OutputTokensColumnName = "OutputTokens" };
44-
var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A");
45+
var pipeline = ML.Transforms.Text.FeaturizeText("Features", options, "A");
4546
var model = pipeline.Fit(dataView);
4647
var engine = model.CreatePredictionEngine<TestClass, TestClass>(ML);
4748
var prediction = engine.Predict(data[0]);
@@ -51,6 +52,95 @@ public void TextFeaturizerWithPredefinedStopWordRemoverTest()
5152
Assert.Equal("stop words", string.Join(" ", prediction.OutputTokens));
5253
}
5354

55+
[Fact]
56+
public void TextFeaturizerWithWordFeatureExtractorTest()
57+
{
58+
var data = new[] { new TestClass() { A = "This is some text in english", OutputTokens=null},
59+
new TestClass() { A = "This is another example", OutputTokens=null } };
60+
var dataView = ML.Data.LoadFromEnumerable(data);
61+
62+
var options = new TextFeaturizingEstimator.Options()
63+
{
64+
WordFeatureExtractor = new WordBagEstimator.Options() { NgramLength = 1 },
65+
CharFeatureExtractor = null,
66+
Norm = TextFeaturizingEstimator.NormFunction.None,
67+
OutputTokensColumnName = "OutputTokens"
68+
};
69+
var pipeline = ML.Transforms.Text.FeaturizeText("Features", options, "A");
70+
var model = pipeline.Fit(dataView);
71+
var engine = model.CreatePredictionEngine<TestClass, TestClass>(ML);
72+
73+
var prediction = engine.Predict(data[0]);
74+
Assert.Equal(data[0].A.ToLower(), string.Join(" ", prediction.OutputTokens));
75+
var expected = new float[] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f };
76+
Assert.Equal(expected, prediction.Features);
77+
78+
prediction = engine.Predict(data[1]);
79+
Assert.Equal(data[1].A.ToLower(), string.Join(" ", prediction.OutputTokens));
80+
expected = new float[] { 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f };
81+
Assert.Equal(expected, prediction.Features);
82+
}
83+
84+
[Fact]
85+
public void TextFeaturizerWithCharFeatureExtractorTest()
86+
{
87+
var data = new[] { new TestClass() { A = "abc efg", OutputTokens=null},
88+
new TestClass() { A = "xyz", OutputTokens=null } };
89+
var dataView = ML.Data.LoadFromEnumerable(data);
90+
91+
var options = new TextFeaturizingEstimator.Options()
92+
{
93+
WordFeatureExtractor = null,
94+
CharFeatureExtractor = new WordBagEstimator.Options() { NgramLength = 1 },
95+
Norm = TextFeaturizingEstimator.NormFunction.None,
96+
OutputTokensColumnName = "OutputTokens"
97+
};
98+
var pipeline = ML.Transforms.Text.FeaturizeText("Features", options, "A");
99+
var model = pipeline.Fit(dataView);
100+
var engine = model.CreatePredictionEngine<TestClass, TestClass>(ML);
101+
102+
var prediction = engine.Predict(data[0]);
103+
Assert.Equal(data[0].A, string.Join(" ", prediction.OutputTokens));
104+
var expected = new float[] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f };
105+
Assert.Equal(expected, prediction.Features);
106+
107+
prediction = engine.Predict(data[1]);
108+
Assert.Equal(data[1].A, string.Join(" ", prediction.OutputTokens));
109+
expected = new float[] { 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 1.0f };
110+
Assert.Equal(expected, prediction.Features);
111+
}
112+
113+
[Fact]
114+
public void TextFeaturizerWithL2NormTest()
115+
{
116+
var data = new[] { new TestClass() { A = "abc xyz", OutputTokens=null},
117+
new TestClass() { A = "xyz", OutputTokens=null } };
118+
var dataView = ML.Data.LoadFromEnumerable(data);
119+
120+
var options = new TextFeaturizingEstimator.Options()
121+
{
122+
CharFeatureExtractor = new WordBagEstimator.Options() { NgramLength = 1},
123+
Norm = TextFeaturizingEstimator.NormFunction.L2,
124+
OutputTokensColumnName = "OutputTokens"
125+
};
126+
var pipeline = ML.Transforms.Text.FeaturizeText("Features", options, "A");
127+
var model = pipeline.Fit(dataView);
128+
var engine = model.CreatePredictionEngine<TestClass, TestClass>(ML);
129+
130+
var prediction = engine.Predict(data[0]);
131+
Assert.Equal(data[0].A, string.Join(" ", prediction.OutputTokens));
132+
var exp1 = 0.333333343f;
133+
var exp2 = 0.707106769f;
134+
var expected = new float[] { exp1, exp1, exp1, exp1, exp1, exp1, exp1, exp1, exp1, exp2, exp2 };
135+
Assert.Equal(expected, prediction.Features);
136+
137+
prediction = engine.Predict(data[1]);
138+
exp1 = 0.4472136f;
139+
Assert.Equal(data[1].A, string.Join(" ", prediction.OutputTokens));
140+
expected = new float[] { exp1, 0.0f, 0.0f, 0.0f, 0.0f, exp1, exp1, exp1, exp1, 0.0f, 1.0f };
141+
Assert.Equal(expected, prediction.Features);
142+
}
143+
54144
[Fact]
55145
public void TextFeaturizerWithCustomStopWordRemoverTest()
56146
{
@@ -67,7 +157,7 @@ public void TextFeaturizerWithCustomStopWordRemoverTest()
67157
OutputTokensColumnName = "OutputTokens",
68158
CaseMode = TextNormalizingEstimator.CaseMode.None
69159
};
70-
var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A");
160+
var pipeline = ML.Transforms.Text.FeaturizeText("Features", options, "A");
71161
var model = pipeline.Fit(dataView);
72162
var engine = model.CreatePredictionEngine<TestClass, TestClass>(ML);
73163
var prediction = engine.Predict(data[0]);
@@ -84,7 +174,7 @@ private void TestCaseMode(IDataView dataView, TestClass[] data, TextNormalizingE
84174
CaseMode = caseMode,
85175
OutputTokensColumnName = "OutputTokens"
86176
};
87-
var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A");
177+
var pipeline = ML.Transforms.Text.FeaturizeText("Features", options, "A");
88178
var model = pipeline.Fit(dataView);
89179
var engine = model.CreatePredictionEngine<TestClass, TestClass>(ML);
90180
var prediction1 = engine.Predict(data[0]);
@@ -133,7 +223,7 @@ private void TestKeepNumbers(IDataView dataView, TestClass[] data, bool keepNumb
133223
CaseMode = TextNormalizingEstimator.CaseMode.None,
134224
OutputTokensColumnName = "OutputTokens"
135225
};
136-
var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A");
226+
var pipeline = ML.Transforms.Text.FeaturizeText("Features", options, "A");
137227
var model = pipeline.Fit(dataView);
138228
var engine = model.CreatePredictionEngine<TestClass, TestClass>(ML);
139229
var prediction1 = engine.Predict(data[0]);
@@ -170,7 +260,7 @@ private void TestKeepPunctuations(IDataView dataView, TestClass[] data, bool kee
170260
CaseMode = TextNormalizingEstimator.CaseMode.None,
171261
OutputTokensColumnName = "OutputTokens"
172262
};
173-
var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A");
263+
var pipeline = ML.Transforms.Text.FeaturizeText("Features", options, "A");
174264
var model = pipeline.Fit(dataView);
175265
var engine = model.CreatePredictionEngine<TestClass, TestClass>(ML);
176266
var prediction1 = engine.Predict(data[0]);
@@ -208,7 +298,7 @@ private void TestKeepDiacritics(IDataView dataView, TestClass[] data, bool keepD
208298
CaseMode = TextNormalizingEstimator.CaseMode.None,
209299
OutputTokensColumnName = "OutputTokens"
210300
};
211-
var pipeline = ML.Transforms.Text.FeaturizeText("OutputText", options, "A");
301+
var pipeline = ML.Transforms.Text.FeaturizeText("Features", options, "A");
212302
var model = pipeline.Fit(dataView);
213303
var engine = model.CreatePredictionEngine<TestClass, TestClass>(ML);
214304
var prediction1 = engine.Predict(data[0]);

0 commit comments

Comments
 (0)