Skip to content

Commit e9a757b

Browse files
committed
Add EarlyStoppingCriteria.cs.
1 parent 7a23bbc commit e9a757b

File tree

1 file changed

+358
-0
lines changed

1 file changed

+358
-0
lines changed
Lines changed: 358 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,358 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Generic;
6+
using Microsoft.ML;
7+
using Microsoft.ML.CommandLine;
8+
using Microsoft.ML.EntryPoints;
9+
using Microsoft.ML.Trainers.FastTree;
10+
using Float = System.Single;
11+
12+
[assembly: LoadableClass(typeof(TolerantEarlyStoppingCriterion), typeof(TolerantEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Tolerant (TR)", "tr")]
13+
[assembly: LoadableClass(typeof(GLEarlyStoppingCriterion), typeof(GLEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Loss of Generality (GL)", "gl")]
14+
[assembly: LoadableClass(typeof(LPEarlyStoppingCriterion), typeof(LPEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Low Progress (LP)", "lp")]
15+
[assembly: LoadableClass(typeof(PQEarlyStoppingCriterion), typeof(PQEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Generality to Progress Ratio (PQ)", "pq")]
16+
[assembly: LoadableClass(typeof(UPEarlyStoppingCriterion), typeof(UPEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Consecutive Loss in Generality (UP)", "up")]
17+
18+
[assembly: EntryPointModule(typeof(TolerantEarlyStoppingCriterion))]
19+
[assembly: EntryPointModule(typeof(GLEarlyStoppingCriterion))]
20+
[assembly: EntryPointModule(typeof(LPEarlyStoppingCriterion))]
21+
[assembly: EntryPointModule(typeof(PQEarlyStoppingCriterion))]
22+
[assembly: EntryPointModule(typeof(UPEarlyStoppingCriterion))]
23+
24+
namespace Microsoft.ML.Trainers.FastTree
25+
{
26+
/// <summary>
/// Signature delegate for early stopping criterion components. The assembly-level
/// LoadableClass attributes in this file pass this delegate type as the signature type
/// when registering each criterion.
/// </summary>
/// <param name="lowerIsBetter">Whether a lower score is considered a better score.</param>
internal delegate void SignatureEarlyStoppingCriterion(bool lowerIsBetter);
27+
28+
// These criteria will be used in FastTree and NeuralNets.
29+
/// <summary>
/// Contract shared by every early stopping rule in this file: after each evaluation
/// the trainer reports its scores and the rule decides whether training should stop.
/// </summary>
public abstract class IEarlyStoppingCriterion
{
    /// <summary>
    /// Decides whether learning should stop, given the latest scores.
    /// </summary>
    /// <param name="validationScore">
    /// Score on the validation set; expected to be non-negative. A higher value is better
    /// unless the criterion was created with lowerIsBetter set to true.
    /// </param>
    /// <param name="trainingScore">
    /// Score on the training set; expected to be non-negative. A higher value is better
    /// unless the criterion was created with lowerIsBetter set to true.
    /// </param>
    /// <param name="isBestCandidate">Set to true when the current result is the best seen so far.</param>
    /// <returns>True when learning should stop; otherwise false.</returns>
    public abstract bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate);
}
40+
41+
/// <summary>
/// Factory for early stopping criteria. The options classes of the concrete criteria
/// in this file implement it so that they can create their criterion instance.
/// </summary>
[TlcModule.ComponentKind("EarlyStoppingCriterion")]
public interface IEarlyStoppingCriterionFactory : IComponentFactory<bool, IEarlyStoppingCriterion>
{
    /// <summary>
    /// Creates the early stopping criterion. Declared with <c>new</c>, so it hides the
    /// member of the same name inherited from the base factory interface.
    /// </summary>
    /// <param name="env">The host environment.</param>
    /// <param name="lowerIsBetter">Whether a lower score is considered a better score.</param>
    /// <returns>The created criterion.</returns>
    new IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter);
}
46+
47+
/// <summary>
/// Base class for early stopping criteria that are configured through an options object
/// and that keep track of the best score observed so far.
/// </summary>
/// <typeparam name="TOptions">The options type of the concrete criterion.</typeparam>
public abstract class EarlyStoppingCriterion<TOptions> : IEarlyStoppingCriterion
    where TOptions : EarlyStoppingCriterion<TOptions>.OptionsBase
{
    /// <summary>Common base type of all criterion option classes.</summary>
    public abstract class OptionsBase { }

    private Float _bestScore;

    /// <summary>The options this criterion was created with.</summary>
    protected readonly TOptions EarlyStoppingCriterionOptions;

    /// <summary>True when a lower score is better; false when a higher score is better.</summary>
    protected readonly bool LowerIsBetter;

    /// <summary>
    /// The best score seen so far. Starts at positive infinity when lower is better and at
    /// negative infinity otherwise, so that the first real score always replaces it.
    /// </summary>
    protected Float BestScore
    {
        get { return _bestScore; }
        set
        {
            // The best score may only move in the "better" direction. The previous form,
            // (LowerIsBetter && value <= _bestScore) || value >= _bestScore, also accepted a
            // worse (higher) value when LowerIsBetter was true, because the second clause was
            // not restricted to the higher-is-better case.
            Contracts.Assert(LowerIsBetter ? value <= _bestScore : value >= _bestScore);
            _bestScore = value;
        }
    }

    /// <summary>
    /// Stores the options and the score direction, and initializes the best score to the
    /// worst possible value for that direction.
    /// </summary>
    /// <param name="options">The criterion options.</param>
    /// <param name="lowerIsBetter">Whether a lower score is considered a better score.</param>
    internal EarlyStoppingCriterion(TOptions options, bool lowerIsBetter)
    {
        EarlyStoppingCriterionOptions = options;
        LowerIsBetter = lowerIsBetter;
        _bestScore = LowerIsBetter ? Float.PositiveInfinity : Float.NegativeInfinity;
    }

    /// <summary>
    /// Check if the given score is the best ever, and if so store it as the best score.
    /// </summary>
    /// <remarks>
    /// When higher is better, the score must be strictly greater than the current best.
    /// When lower is better, any score that is not greater than the current best qualifies,
    /// so a score equal to the current best is also reported as the best.
    /// </remarks>
    /// <param name="score">The latest score.</param>
    /// <returns>True if the given score is the best ever.</returns>
    protected bool CheckBestScore(Float score)
    {
        bool isBestEver = ((score > BestScore) != LowerIsBetter);
        if (isBestEver)
            BestScore = score;

        return isBestEver;
    }
}
86+
87+
/// <summary>
/// Tolerant (TR) rule: stop once the validation score is worse than the best score
/// seen so far by more than an absolute threshold.
/// </summary>
public sealed class TolerantEarlyStoppingCriterion : EarlyStoppingCriterion<TolerantEarlyStoppingCriterion.Options>
{
    /// <summary>Options for the tolerant rule; also acts as its factory.</summary>
    [TlcModule.Component(FriendlyName = "Tolerant (TR)", Name = "TR", Desc = "Stop if validation score exceeds threshold value.")]
    public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory
    {
        /// <summary>Largest tolerated gap between the best score and the current score.</summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance threshold. (Non negative value)", ShortName = "th")]
        [TlcModule.Range(Min = 0.0f)]
        public float Threshold = 0.01f;

        /// <summary>Creates a tolerant criterion configured by these options.</summary>
        public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter)
            => new TolerantEarlyStoppingCriterion(this, lowerIsBetter);
    }

    /// <summary>Creates the criterion and validates that the threshold is non-negative.</summary>
    /// <param name="options">The rule options.</param>
    /// <param name="lowerIsBetter">Whether a lower score is considered a better score.</param>
    public TolerantEarlyStoppingCriterion(Options options, bool lowerIsBetter)
        : base(options, lowerIsBetter)
    {
        Contracts.CheckUserArg(EarlyStoppingCriterionOptions.Threshold >= 0, nameof(options.Threshold), "Must be non-negative.");
    }

    /// <summary>
    /// Updates the best score with the validation score and reports whether the validation
    /// score has fallen behind the best score by more than the threshold.
    /// </summary>
    /// <param name="validationScore">The latest validation score; asserted to be non-negative.</param>
    /// <param name="trainingScore">Not used by this rule.</param>
    /// <param name="isBestCandidate">Set to true when the validation score is the best so far.</param>
    /// <returns>True when learning should stop.</returns>
    public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate)
    {
        Contracts.Assert(validationScore >= 0);

        isBestCandidate = CheckBestScore(validationScore);

        // The gap is measured in the direction in which the score gets worse.
        return LowerIsBetter
            ? (validationScore - BestScore > EarlyStoppingCriterionOptions.Threshold)
            : (BestScore - validationScore > EarlyStoppingCriterionOptions.Threshold);
    }
}
120+
121+
// For the detail of the following rules, see the following paper.
122+
// Lodwich, Aleksander, Yves Rangoni, and Thomas Breuel. "Evaluation of robustness and performance of early stopping rules with multi layer perceptrons."
123+
// Neural Networks, 2009. IJCNN 2009. International Joint Conference on. IEEE, 2009.
124+
125+
/// <summary>
/// Base class for rules that look at a sliding window of the most recent scores.
/// </summary>
public abstract class MovingWindowEarlyStoppingCriterion : EarlyStoppingCriterion<MovingWindowEarlyStoppingCriterion.Options>
{
    /// <summary>Options shared by the moving-window rules.</summary>
    public class Options : OptionsBase
    {
        /// <summary>Relative threshold used by the derived rules.</summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Threshold in range [0,1].", ShortName = "th")]
        [TlcModule.Range(Min = 0.0f, Max = 1.0f)]
        public Float Threshold = 0.01f;

        /// <summary>Number of recent scores kept in the window.</summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "The window size.", ShortName = "w")]
        [TlcModule.Range(Inf = 0)]
        public int WindowSize = 5;
    }

    /// <summary>The most recent scores, oldest first.</summary>
    protected Queue<Float> PastScores;

    /// <summary>
    /// Validates the threshold and the window size, then allocates the score window.
    /// </summary>
    /// <param name="args">The rule options.</param>
    /// <param name="lowerIsBetter">Whether a lower score is considered a better score.</param>
    private protected MovingWindowEarlyStoppingCriterion(Options args, bool lowerIsBetter)
        : base(args, lowerIsBetter)
    {
        Contracts.CheckUserArg(0 <= EarlyStoppingCriterionOptions.Threshold && args.Threshold <= 1, nameof(args.Threshold), "Must be in range [0,1].");
        Contracts.CheckUserArg(EarlyStoppingCriterionOptions.WindowSize > 0, nameof(args.WindowSize), "Must be positive.");

        PastScores = new Queue<Float>(EarlyStoppingCriterionOptions.WindowSize);
    }

    /// <summary>
    /// Computes the arithmetic mean of the given scores.
    /// </summary>
    /// <param name="recentScores">The scores to average; asserted to be non-empty.</param>
    /// <returns>The moving average.</returns>
    private Float GetRecentAvg(Queue<Float> recentScores)
    {
        Float total = 0;
        foreach (Float item in recentScores)
        {
            total += item;
        }

        Contracts.Assert(recentScores.Count > 0);
        return total / recentScores.Count;
    }

    /// <summary>
    /// Finds the best of the given scores, honoring the score direction.
    /// </summary>
    /// <param name="recentScores">The list of scores.</param>
    /// <returns>The best score.</returns>
    private Float GetRecentBest(IEnumerable<Float> recentScores)
    {
        // Start from the worst possible value for the current direction.
        Float best = LowerIsBetter ? Float.PositiveInfinity : Float.NegativeInfinity;
        foreach (Float candidate in recentScores)
        {
            bool beatsCurrent = (candidate > best) != LowerIsBetter;
            if (beatsCurrent)
            {
                best = candidate;
            }
        }

        return best;
    }

    /// <summary>
    /// Adds a score to the window and, if the window was already full before this call,
    /// reports the best and the average of the updated window.
    /// </summary>
    /// <param name="score">The score to add.</param>
    /// <param name="windowSize">The number of scores the window holds.</param>
    /// <param name="recentBest">The best score in the window, or the default value when false is returned.</param>
    /// <param name="recentAverage">The average of the window, or the default value when false is returned.</param>
    /// <returns>True when the window already held at least <paramref name="windowSize"/> scores before this call.</returns>
    protected bool CheckRecentScores(Float score, int windowSize, out Float recentBest, out Float recentAverage)
    {
        bool windowWasFull = PastScores.Count >= windowSize;

        // Drop the oldest score first so that the window never grows past its size.
        if (windowWasFull)
        {
            PastScores.Dequeue();
        }

        PastScores.Enqueue(score);

        if (!windowWasFull)
        {
            recentBest = default(Float);
            recentAverage = default(Float);
            return false;
        }

        recentAverage = GetRecentAvg(PastScores);
        recentBest = GetRecentBest(PastScores);
        return true;
    }
}
200+
201+
/// <summary>
202+
/// Loss of Generality (GL).
203+
/// </summary>
204+
public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion<GLEarlyStoppingCriterion.Options>
205+
{
206+
[TlcModule.Component(FriendlyName = "Loss of Generality (GL)", Name = "GL",
207+
Desc = "Stop in case of loss of generality.")]
208+
public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory
209+
{
210+
[Argument(ArgumentType.AtMostOnce, HelpText = "Threshold in range [0,1].", ShortName = "th")]
211+
[TlcModule.Range(Min = 0.0f, Max = 1.0f)]
212+
public float Threshold = 0.01f;
213+
214+
public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter)
215+
{
216+
return new GLEarlyStoppingCriterion(this, lowerIsBetter);
217+
}
218+
}
219+
220+
public GLEarlyStoppingCriterion(Options options, bool lowerIsBetter)
221+
: base(options, lowerIsBetter)
222+
{
223+
Contracts.CheckUserArg(0 <= EarlyStoppingCriterionOptions.Threshold && options.Threshold <= 1, nameof(options.Threshold), "Must be in range [0,1].");
224+
}
225+
226+
public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate)
227+
{
228+
Contracts.Assert(validationScore >= 0);
229+
230+
isBestCandidate = CheckBestScore(validationScore);
231+
232+
if (LowerIsBetter)
233+
return (validationScore > (1 + EarlyStoppingCriterionOptions.Threshold) * BestScore);
234+
else
235+
return (validationScore < (1 - EarlyStoppingCriterionOptions.Threshold) * BestScore);
236+
}
237+
}
238+
239+
/// <summary>
240+
/// Low Progress (LP).
241+
/// This rule fires when the improvements on the score stall.
242+
/// </summary>
243+
public sealed class LPEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterion
244+
{
245+
[TlcModule.Component(FriendlyName = "Low Progress (LP)", Name = "LP", Desc = "Stops in case of low progress.")]
246+
public new sealed class Options : MovingWindowEarlyStoppingCriterion.Options, IEarlyStoppingCriterionFactory
247+
{
248+
public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter)
249+
{
250+
return new LPEarlyStoppingCriterion(this, lowerIsBetter);
251+
}
252+
}
253+
254+
public LPEarlyStoppingCriterion(Options options, bool lowerIsBetter)
255+
: base(options, lowerIsBetter) { }
256+
257+
public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate)
258+
{
259+
Contracts.Assert(validationScore >= 0);
260+
Contracts.Assert(trainingScore >= 0);
261+
262+
isBestCandidate = CheckBestScore(validationScore);
263+
264+
Float recentBest;
265+
Float recentAverage;
266+
if (CheckRecentScores(trainingScore, EarlyStoppingCriterionOptions.WindowSize, out recentBest, out recentAverage))
267+
{
268+
if (LowerIsBetter)
269+
return (recentAverage <= (1 + EarlyStoppingCriterionOptions.Threshold) * recentBest);
270+
else
271+
return (recentAverage >= (1 - EarlyStoppingCriterionOptions.Threshold) * recentBest);
272+
}
273+
274+
return false;
275+
}
276+
}
277+
278+
/// <summary>
279+
/// Generality to Progress Ratio (PQ).
280+
/// </summary>
281+
public sealed class PQEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterion
282+
{
283+
[TlcModule.Component(FriendlyName = "Generality to Progress Ratio (PQ)", Name = "PQ", Desc = "Stops in case of generality to progress ration exceeds threshold.")]
284+
public new sealed class Options : MovingWindowEarlyStoppingCriterion.Options, IEarlyStoppingCriterionFactory
285+
{
286+
public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter)
287+
{
288+
return new PQEarlyStoppingCriterion(this, lowerIsBetter);
289+
}
290+
}
291+
292+
public PQEarlyStoppingCriterion(Options options, bool lowerIsBetter)
293+
: base(options, lowerIsBetter) { }
294+
295+
public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate)
296+
{
297+
Contracts.Assert(validationScore >= 0);
298+
Contracts.Assert(trainingScore >= 0);
299+
300+
isBestCandidate = CheckBestScore(validationScore);
301+
302+
Float recentBest;
303+
Float recentAverage;
304+
if (CheckRecentScores(trainingScore, EarlyStoppingCriterionOptions.WindowSize, out recentBest, out recentAverage))
305+
{
306+
if (LowerIsBetter)
307+
return (validationScore * recentBest >= (1 + EarlyStoppingCriterionOptions.Threshold) * BestScore * recentAverage);
308+
else
309+
return (validationScore * recentBest <= (1 - EarlyStoppingCriterionOptions.Threshold) * BestScore * recentAverage);
310+
}
311+
312+
return false;
313+
}
314+
}
315+
316+
/// <summary>
317+
/// Consecutive Loss in Generality (UP).
318+
/// </summary>
319+
public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion<UPEarlyStoppingCriterion.Options>
320+
{
321+
[TlcModule.Component(FriendlyName = "Consecutive Loss in Generality (UP)", Name = "UP",
322+
Desc = "Stops in case of consecutive loss in generality.")]
323+
public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory
324+
{
325+
[Argument(ArgumentType.AtMostOnce, HelpText = "The window size.", ShortName = "w")]
326+
[TlcModule.Range(Inf = 0)]
327+
public int WindowSize = 5;
328+
329+
public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter)
330+
{
331+
return new UPEarlyStoppingCriterion(this, lowerIsBetter);
332+
}
333+
}
334+
335+
private int _count;
336+
private Float _prevScore;
337+
338+
public UPEarlyStoppingCriterion(Options options, bool lowerIsBetter)
339+
: base(options, lowerIsBetter)
340+
{
341+
Contracts.CheckUserArg(EarlyStoppingCriterionOptions.WindowSize > 0, nameof(options.WindowSize), "Must be positive");
342+
343+
_prevScore = LowerIsBetter ? Float.PositiveInfinity : Float.NegativeInfinity;
344+
}
345+
346+
public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate)
347+
{
348+
Contracts.Assert(validationScore >= 0);
349+
350+
isBestCandidate = CheckBestScore(validationScore);
351+
352+
_count = ((validationScore < _prevScore) != LowerIsBetter) ? _count + 1 : 0;
353+
_prevScore = validationScore;
354+
355+
return (_count >= EarlyStoppingCriterionOptions.WindowSize);
356+
}
357+
}
358+
}

0 commit comments

Comments
 (0)