|
| 1 | +// Licensed to the .NET Foundation under one or more agreements. |
| 2 | +// The .NET Foundation licenses this file to you under the MIT license. |
| 3 | +// See the LICENSE file in the project root for more information. |
| 4 | + |
| 5 | +using System.Collections.Generic; |
| 6 | +using Microsoft.ML; |
| 7 | +using Microsoft.ML.CommandLine; |
| 8 | +using Microsoft.ML.EntryPoints; |
| 9 | +using Microsoft.ML.Trainers.FastTree; |
| 10 | +using Float = System.Single; |
| 11 | + |
| 12 | +[assembly: LoadableClass(typeof(TolerantEarlyStoppingCriterion), typeof(TolerantEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Tolerant (TR)", "tr")] |
| 13 | +[assembly: LoadableClass(typeof(GLEarlyStoppingCriterion), typeof(GLEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Loss of Generality (GL)", "gl")] |
| 14 | +[assembly: LoadableClass(typeof(LPEarlyStoppingCriterion), typeof(LPEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Low Progress (LP)", "lp")] |
| 15 | +[assembly: LoadableClass(typeof(PQEarlyStoppingCriterion), typeof(PQEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Generality to Progress Ratio (PQ)", "pq")] |
| 16 | +[assembly: LoadableClass(typeof(UPEarlyStoppingCriterion), typeof(UPEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Consecutive Loss in Generality (UP)", "up")] |
| 17 | + |
| 18 | +[assembly: EntryPointModule(typeof(TolerantEarlyStoppingCriterion))] |
| 19 | +[assembly: EntryPointModule(typeof(GLEarlyStoppingCriterion))] |
| 20 | +[assembly: EntryPointModule(typeof(LPEarlyStoppingCriterion))] |
| 21 | +[assembly: EntryPointModule(typeof(PQEarlyStoppingCriterion))] |
| 22 | +[assembly: EntryPointModule(typeof(UPEarlyStoppingCriterion))] |
| 23 | + |
| 24 | +namespace Microsoft.ML.Trainers.FastTree |
| 25 | +{ |
| 26 | + internal delegate void SignatureEarlyStoppingCriterion(bool lowerIsBetter); |
| 27 | + |
| 28 | + // These criteria will be used in FastTree and NeuralNets. |
| 29 | + public abstract class IEarlyStoppingCriterion |
| 30 | + { |
| 31 | + /// <summary> |
| 32 | + /// Check if the learning should stop or not. |
| 33 | + /// </summary> |
| 34 | + /// <param name="validationScore">A non negative number. Higher score means better result unless "_lowerIsBetter" is true.</param> |
| 35 | + /// <param name="trainingScore">A non negative number. Higher score means better result unless "_lowerIsBetter" is true.</param> |
| 36 | + /// <param name="isBestCandidate">True if the current result is the best ever.</param> |
| 37 | + /// <returns>If true, the learning should stop.</returns> |
| 38 | + public abstract bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate); |
| 39 | + } |
| 40 | + |
| 41 | + [TlcModule.ComponentKind("EarlyStoppingCriterion")] |
| 42 | + public interface IEarlyStoppingCriterionFactory : IComponentFactory<bool, IEarlyStoppingCriterion> |
| 43 | + { |
| 44 | + new IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter); |
| 45 | + } |
| 46 | + |
| 47 | + public abstract class EarlyStoppingCriterion<TOptions> : IEarlyStoppingCriterion |
| 48 | + where TOptions : EarlyStoppingCriterion<TOptions>.OptionsBase |
| 49 | + { |
| 50 | + public abstract class OptionsBase { } |
| 51 | + |
| 52 | + private Float _bestScore; |
| 53 | + |
| 54 | + protected readonly TOptions EarlyStoppingCriterionOptions; |
| 55 | + protected readonly bool LowerIsBetter; |
| 56 | + protected Float BestScore { |
| 57 | + get { return _bestScore; } |
| 58 | + set |
| 59 | + { |
| 60 | + Contracts.Assert((LowerIsBetter && value <= _bestScore) || value >= _bestScore); |
| 61 | + _bestScore = value; |
| 62 | + } |
| 63 | + } |
| 64 | + |
| 65 | + internal EarlyStoppingCriterion(TOptions options, bool lowerIsBetter) |
| 66 | + { |
| 67 | + EarlyStoppingCriterionOptions = options; |
| 68 | + LowerIsBetter = lowerIsBetter; |
| 69 | + _bestScore = LowerIsBetter ? Float.PositiveInfinity : Float.NegativeInfinity; |
| 70 | + } |
| 71 | + |
| 72 | + /// <summary> |
| 73 | + /// Check if the given score is the best ever. The best score will be stored at this._bestScore. |
| 74 | + /// </summary> |
| 75 | + /// <param name="score">The latest score</param> |
| 76 | + /// <returns>True if the given score is the best ever.</returns> |
| 77 | + protected bool CheckBestScore(Float score) |
| 78 | + { |
| 79 | + bool isBestEver = ((score > BestScore) != LowerIsBetter); |
| 80 | + if (isBestEver) |
| 81 | + BestScore = score; |
| 82 | + |
| 83 | + return isBestEver; |
| 84 | + } |
| 85 | + } |
| 86 | + |
| 87 | + public sealed class TolerantEarlyStoppingCriterion : EarlyStoppingCriterion<TolerantEarlyStoppingCriterion.Options> |
| 88 | + { |
| 89 | + [TlcModule.Component(FriendlyName = "Tolerant (TR)", Name = "TR", Desc = "Stop if validation score exceeds threshold value.")] |
| 90 | + public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory |
| 91 | + { |
| 92 | + [Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance threshold. (Non negative value)", ShortName = "th")] |
| 93 | + [TlcModule.Range(Min = 0.0f)] |
| 94 | + public float Threshold = 0.01f; |
| 95 | + |
| 96 | + public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) |
| 97 | + { |
| 98 | + return new TolerantEarlyStoppingCriterion(this, lowerIsBetter); |
| 99 | + } |
| 100 | + } |
| 101 | + |
| 102 | + public TolerantEarlyStoppingCriterion(Options options, bool lowerIsBetter) |
| 103 | + : base(options, lowerIsBetter) |
| 104 | + { |
| 105 | + Contracts.CheckUserArg(EarlyStoppingCriterionOptions.Threshold >= 0, nameof(options.Threshold), "Must be non-negative."); |
| 106 | + } |
| 107 | + |
| 108 | + public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate) |
| 109 | + { |
| 110 | + Contracts.Assert(validationScore >= 0); |
| 111 | + |
| 112 | + isBestCandidate = CheckBestScore(validationScore); |
| 113 | + |
| 114 | + if (LowerIsBetter) |
| 115 | + return (validationScore - BestScore > EarlyStoppingCriterionOptions.Threshold); |
| 116 | + else |
| 117 | + return (BestScore - validationScore > EarlyStoppingCriterionOptions.Threshold); |
| 118 | + } |
| 119 | + } |
| 120 | + |
| 121 | + // For the detail of the following rules, see the following paper. |
| 122 | + // Lodwich, Aleksander, Yves Rangoni, and Thomas Breuel. "Evaluation of robustness and performance of early stopping rules with multi layer perceptrons." |
| 123 | + // Neural Networks, 2009. IJCNN 2009. International Joint Conference on. IEEE, 2009. |
| 124 | + |
| 125 | + public abstract class MovingWindowEarlyStoppingCriterion : EarlyStoppingCriterion<MovingWindowEarlyStoppingCriterion.Options> |
| 126 | + { |
| 127 | + public class Options : OptionsBase |
| 128 | + { |
| 129 | + [Argument(ArgumentType.AtMostOnce, HelpText = "Threshold in range [0,1].", ShortName = "th")] |
| 130 | + [TlcModule.Range(Min = 0.0f, Max = 1.0f)] |
| 131 | + public Float Threshold = 0.01f; |
| 132 | + |
| 133 | + [Argument(ArgumentType.AtMostOnce, HelpText = "The window size.", ShortName = "w")] |
| 134 | + [TlcModule.Range(Inf = 0)] |
| 135 | + public int WindowSize = 5; |
| 136 | + } |
| 137 | + |
| 138 | + protected Queue<Float> PastScores; |
| 139 | + |
| 140 | + private protected MovingWindowEarlyStoppingCriterion(Options args, bool lowerIsBetter) |
| 141 | + : base(args, lowerIsBetter) |
| 142 | + { |
| 143 | + Contracts.CheckUserArg(0 <= EarlyStoppingCriterionOptions.Threshold && args.Threshold <= 1, nameof(args.Threshold), "Must be in range [0,1]."); |
| 144 | + Contracts.CheckUserArg(EarlyStoppingCriterionOptions.WindowSize > 0, nameof(args.WindowSize), "Must be positive."); |
| 145 | + |
| 146 | + PastScores = new Queue<Float>(EarlyStoppingCriterionOptions.WindowSize); |
| 147 | + } |
| 148 | + |
| 149 | + /// <summary> |
| 150 | + /// Calculate the average score in the given list of scores. |
| 151 | + /// </summary> |
| 152 | + /// <returns>The moving average.</returns> |
| 153 | + private Float GetRecentAvg(Queue<Float> recentScores) |
| 154 | + { |
| 155 | + Float avg = 0; |
| 156 | + |
| 157 | + foreach (Float score in recentScores) |
| 158 | + avg += score; |
| 159 | + |
| 160 | + Contracts.Assert(recentScores.Count > 0); |
| 161 | + return avg / recentScores.Count; |
| 162 | + } |
| 163 | + |
| 164 | + /// <summary> |
| 165 | + /// Get the best score in the given list of scores. |
| 166 | + /// </summary> |
| 167 | + /// <param name="recentScores">The list of scores.</param> |
| 168 | + /// <returns>The best score.</returns> |
| 169 | + private Float GetRecentBest(IEnumerable<Float> recentScores) |
| 170 | + { |
| 171 | + Float recentBestScore = LowerIsBetter ? Float.PositiveInfinity : Float.NegativeInfinity; |
| 172 | + foreach (Float score in recentScores) |
| 173 | + { |
| 174 | + if ((score > recentBestScore) != LowerIsBetter) |
| 175 | + recentBestScore = score; |
| 176 | + } |
| 177 | + |
| 178 | + return recentBestScore; |
| 179 | + } |
| 180 | + |
| 181 | + protected bool CheckRecentScores(Float score, int windowSize, out Float recentBest, out Float recentAverage) |
| 182 | + { |
| 183 | + if (PastScores.Count >= windowSize) |
| 184 | + { |
| 185 | + PastScores.Dequeue(); |
| 186 | + PastScores.Enqueue(score); |
| 187 | + recentAverage = GetRecentAvg(PastScores); |
| 188 | + recentBest = GetRecentBest(PastScores); |
| 189 | + return true; |
| 190 | + } |
| 191 | + else |
| 192 | + { |
| 193 | + PastScores.Enqueue(score); |
| 194 | + recentBest = default(Float); |
| 195 | + recentAverage = default(Float); |
| 196 | + return false; |
| 197 | + } |
| 198 | + } |
| 199 | + } |
| 200 | + |
| 201 | + /// <summary> |
| 202 | + /// Loss of Generality (GL). |
| 203 | + /// </summary> |
| 204 | + public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion<GLEarlyStoppingCriterion.Options> |
| 205 | + { |
| 206 | + [TlcModule.Component(FriendlyName = "Loss of Generality (GL)", Name = "GL", |
| 207 | + Desc = "Stop in case of loss of generality.")] |
| 208 | + public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory |
| 209 | + { |
| 210 | + [Argument(ArgumentType.AtMostOnce, HelpText = "Threshold in range [0,1].", ShortName = "th")] |
| 211 | + [TlcModule.Range(Min = 0.0f, Max = 1.0f)] |
| 212 | + public float Threshold = 0.01f; |
| 213 | + |
| 214 | + public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) |
| 215 | + { |
| 216 | + return new GLEarlyStoppingCriterion(this, lowerIsBetter); |
| 217 | + } |
| 218 | + } |
| 219 | + |
| 220 | + public GLEarlyStoppingCriterion(Options options, bool lowerIsBetter) |
| 221 | + : base(options, lowerIsBetter) |
| 222 | + { |
| 223 | + Contracts.CheckUserArg(0 <= EarlyStoppingCriterionOptions.Threshold && options.Threshold <= 1, nameof(options.Threshold), "Must be in range [0,1]."); |
| 224 | + } |
| 225 | + |
| 226 | + public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate) |
| 227 | + { |
| 228 | + Contracts.Assert(validationScore >= 0); |
| 229 | + |
| 230 | + isBestCandidate = CheckBestScore(validationScore); |
| 231 | + |
| 232 | + if (LowerIsBetter) |
| 233 | + return (validationScore > (1 + EarlyStoppingCriterionOptions.Threshold) * BestScore); |
| 234 | + else |
| 235 | + return (validationScore < (1 - EarlyStoppingCriterionOptions.Threshold) * BestScore); |
| 236 | + } |
| 237 | + } |
| 238 | + |
| 239 | + /// <summary> |
| 240 | + /// Low Progress (LP). |
| 241 | + /// This rule fires when the improvements on the score stall. |
| 242 | + /// </summary> |
| 243 | + public sealed class LPEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterion |
| 244 | + { |
| 245 | + [TlcModule.Component(FriendlyName = "Low Progress (LP)", Name = "LP", Desc = "Stops in case of low progress.")] |
| 246 | + public new sealed class Options : MovingWindowEarlyStoppingCriterion.Options, IEarlyStoppingCriterionFactory |
| 247 | + { |
| 248 | + public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) |
| 249 | + { |
| 250 | + return new LPEarlyStoppingCriterion(this, lowerIsBetter); |
| 251 | + } |
| 252 | + } |
| 253 | + |
| 254 | + public LPEarlyStoppingCriterion(Options options, bool lowerIsBetter) |
| 255 | + : base(options, lowerIsBetter) { } |
| 256 | + |
| 257 | + public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate) |
| 258 | + { |
| 259 | + Contracts.Assert(validationScore >= 0); |
| 260 | + Contracts.Assert(trainingScore >= 0); |
| 261 | + |
| 262 | + isBestCandidate = CheckBestScore(validationScore); |
| 263 | + |
| 264 | + Float recentBest; |
| 265 | + Float recentAverage; |
| 266 | + if (CheckRecentScores(trainingScore, EarlyStoppingCriterionOptions.WindowSize, out recentBest, out recentAverage)) |
| 267 | + { |
| 268 | + if (LowerIsBetter) |
| 269 | + return (recentAverage <= (1 + EarlyStoppingCriterionOptions.Threshold) * recentBest); |
| 270 | + else |
| 271 | + return (recentAverage >= (1 - EarlyStoppingCriterionOptions.Threshold) * recentBest); |
| 272 | + } |
| 273 | + |
| 274 | + return false; |
| 275 | + } |
| 276 | + } |
| 277 | + |
| 278 | + /// <summary> |
| 279 | + /// Generality to Progress Ratio (PQ). |
| 280 | + /// </summary> |
| 281 | + public sealed class PQEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterion |
| 282 | + { |
| 283 | + [TlcModule.Component(FriendlyName = "Generality to Progress Ratio (PQ)", Name = "PQ", Desc = "Stops in case of generality to progress ration exceeds threshold.")] |
| 284 | + public new sealed class Options : MovingWindowEarlyStoppingCriterion.Options, IEarlyStoppingCriterionFactory |
| 285 | + { |
| 286 | + public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) |
| 287 | + { |
| 288 | + return new PQEarlyStoppingCriterion(this, lowerIsBetter); |
| 289 | + } |
| 290 | + } |
| 291 | + |
| 292 | + public PQEarlyStoppingCriterion(Options options, bool lowerIsBetter) |
| 293 | + : base(options, lowerIsBetter) { } |
| 294 | + |
| 295 | + public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate) |
| 296 | + { |
| 297 | + Contracts.Assert(validationScore >= 0); |
| 298 | + Contracts.Assert(trainingScore >= 0); |
| 299 | + |
| 300 | + isBestCandidate = CheckBestScore(validationScore); |
| 301 | + |
| 302 | + Float recentBest; |
| 303 | + Float recentAverage; |
| 304 | + if (CheckRecentScores(trainingScore, EarlyStoppingCriterionOptions.WindowSize, out recentBest, out recentAverage)) |
| 305 | + { |
| 306 | + if (LowerIsBetter) |
| 307 | + return (validationScore * recentBest >= (1 + EarlyStoppingCriterionOptions.Threshold) * BestScore * recentAverage); |
| 308 | + else |
| 309 | + return (validationScore * recentBest <= (1 - EarlyStoppingCriterionOptions.Threshold) * BestScore * recentAverage); |
| 310 | + } |
| 311 | + |
| 312 | + return false; |
| 313 | + } |
| 314 | + } |
| 315 | + |
| 316 | + /// <summary> |
| 317 | + /// Consecutive Loss in Generality (UP). |
| 318 | + /// </summary> |
| 319 | + public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion<UPEarlyStoppingCriterion.Options> |
| 320 | + { |
| 321 | + [TlcModule.Component(FriendlyName = "Consecutive Loss in Generality (UP)", Name = "UP", |
| 322 | + Desc = "Stops in case of consecutive loss in generality.")] |
| 323 | + public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory |
| 324 | + { |
| 325 | + [Argument(ArgumentType.AtMostOnce, HelpText = "The window size.", ShortName = "w")] |
| 326 | + [TlcModule.Range(Inf = 0)] |
| 327 | + public int WindowSize = 5; |
| 328 | + |
| 329 | + public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) |
| 330 | + { |
| 331 | + return new UPEarlyStoppingCriterion(this, lowerIsBetter); |
| 332 | + } |
| 333 | + } |
| 334 | + |
| 335 | + private int _count; |
| 336 | + private Float _prevScore; |
| 337 | + |
| 338 | + public UPEarlyStoppingCriterion(Options options, bool lowerIsBetter) |
| 339 | + : base(options, lowerIsBetter) |
| 340 | + { |
| 341 | + Contracts.CheckUserArg(EarlyStoppingCriterionOptions.WindowSize > 0, nameof(options.WindowSize), "Must be positive"); |
| 342 | + |
| 343 | + _prevScore = LowerIsBetter ? Float.PositiveInfinity : Float.NegativeInfinity; |
| 344 | + } |
| 345 | + |
| 346 | + public override bool CheckScore(Float validationScore, Float trainingScore, out bool isBestCandidate) |
| 347 | + { |
| 348 | + Contracts.Assert(validationScore >= 0); |
| 349 | + |
| 350 | + isBestCandidate = CheckBestScore(validationScore); |
| 351 | + |
| 352 | + _count = ((validationScore < _prevScore) != LowerIsBetter) ? _count + 1 : 0; |
| 353 | + _prevScore = validationScore; |
| 354 | + |
| 355 | + return (_count >= EarlyStoppingCriterionOptions.WindowSize); |
| 356 | + } |
| 357 | + } |
| 358 | +} |
0 commit comments