Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 02e85cc

Browse files
tannergoodingsfilipi
authored andcommittedOct 5, 2018
Fix MatchNumberWithTolerance to better compare floating-point values (#1145)
* Fix MatchNumberWithTolerance to better compare floating-point values * Updating CheckEqualityFromPathsCore to allow a tolerance match on Windows
1 parent d517589 commit 02e85cc

File tree

4 files changed

+48
-33
lines changed

4 files changed

+48
-33
lines changed
 

‎test/Microsoft.ML.Predictor.Tests/TestPredictors.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ public void EarlyStoppingTest()
161161
[TestCategory("Logistic Regression")]
162162
public void MulticlassLRTest()
163163
{
164-
RunOneAllTests(TestLearners.multiclassLogisticRegression, TestDatasets.iris, precision: 10_000);
164+
RunOneAllTests(TestLearners.multiclassLogisticRegression, TestDatasets.iris, digitsOfPrecision: 4);
165165
Done();
166166
}
167167

@@ -173,7 +173,7 @@ public void MulticlassLRTest()
173173
[TestCategory("Logistic Regression")]
174174
public void MulticlassLRNonNegativeTest()
175175
{
176-
RunOneAllTests(TestLearners.multiclassLogisticRegressionNonNegative, TestDatasets.iris, precision: 10_000);
176+
RunOneAllTests(TestLearners.multiclassLogisticRegressionNonNegative, TestDatasets.iris, digitsOfPrecision: 4);
177177
Done();
178178
}
179179

@@ -200,8 +200,8 @@ public void MulticlassTreeFeaturizedLRTest()
200200
{
201201
RunMTAThread(() =>
202202
{
203-
RunOneAllTests(TestLearners.multiclassLogisticRegression, TestDatasets.irisTreeFeaturized, precision: 10_000);
204-
RunOneAllTests(TestLearners.multiclassLogisticRegression, TestDatasets.irisTreeFeaturizedPermuted, precision: 10_000);
203+
RunOneAllTests(TestLearners.multiclassLogisticRegression, TestDatasets.irisTreeFeaturized, digitsOfPrecision: 4);
204+
RunOneAllTests(TestLearners.multiclassLogisticRegression, TestDatasets.irisTreeFeaturizedPermuted, digitsOfPrecision: 4);
205205
});
206206
Done();
207207
}

‎test/Microsoft.ML.TestFramework/BaseTestBaseline.cs

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace Microsoft.ML.Runtime.RunTests
2323
/// </summary>
2424
public abstract partial class BaseTestBaseline : BaseTestClass
2525
{
26-
public const decimal Tolerance = 10_000_000;
26+
public const int DigitsOfPrecision = 7;
2727

2828
protected BaseTestBaseline(ITestOutputHelper output) : base(output)
2929
{
@@ -352,12 +352,12 @@ protected bool CheckEquality(string dir, string name, string nameBase = null)
352352
/// Check whether two files are same ignoring volatile differences (path, dates, times, etc).
353353
/// Returns true if the check passes.
354354
/// </summary>
355-
protected bool CheckEqualityNormalized(string dir, string name, string nameBase = null, decimal precision = Tolerance)
355+
protected bool CheckEqualityNormalized(string dir, string name, string nameBase = null, int digitsOfPrecision = DigitsOfPrecision)
356356
{
357-
return CheckEqualityCore(dir, name, nameBase ?? name, true, precision);
357+
return CheckEqualityCore(dir, name, nameBase ?? name, true, digitsOfPrecision);
358358
}
359359

360-
protected bool CheckEqualityCore(string dir, string name, string nameBase, bool normalize, decimal precision = Tolerance)
360+
protected bool CheckEqualityCore(string dir, string name, string nameBase, bool normalize, int digitsOfPrecision = DigitsOfPrecision)
361361
{
362362
Contracts.Assert(IsActive);
363363
Contracts.AssertValue(dir); // Can be empty.
@@ -384,7 +384,7 @@ protected bool CheckEqualityCore(string dir, string name, string nameBase, bool
384384
if (!CheckBaseFile(basePath))
385385
return false;
386386

387-
bool res = CheckEqualityFromPathsCore(relPath, basePath, outPath, precision: precision);
387+
bool res = CheckEqualityFromPathsCore(relPath, basePath, outPath, digitsOfPrecision: digitsOfPrecision);
388388

389389
// No need to keep the raw (unnormalized) output file.
390390
if (normalize && res)
@@ -501,7 +501,7 @@ protected bool CheckOutputIsSuffix(string basePath, string outPath, int skip = 0
501501
/// skipping the given number of lines on the output, and finding the corresponding line
502502
/// in the baseline.
503503
/// </summary>
504-
protected bool CheckEqualityNormalized(string dir, string name, string suffix, int skip, decimal precision = Tolerance)
504+
protected bool CheckEqualityNormalized(string dir, string name, string suffix, int skip, int digitsOfPrecision = DigitsOfPrecision)
505505
{
506506
Contracts.Assert(IsActive);
507507
Contracts.AssertValue(dir); // Can be empty.
@@ -522,7 +522,7 @@ protected bool CheckEqualityNormalized(string dir, string name, string suffix, i
522522
if (!CheckBaseFile(basePath))
523523
return false;
524524

525-
bool res = CheckEqualityFromPathsCore(relPath, basePath, outPath, skip, precision);
525+
bool res = CheckEqualityFromPathsCore(relPath, basePath, outPath, skip, digitsOfPrecision);
526526

527527
// No need to keep the raw (unnormalized) output file.
528528
if (res)
@@ -531,7 +531,7 @@ protected bool CheckEqualityNormalized(string dir, string name, string suffix, i
531531
return res;
532532
}
533533

534-
protected bool CheckEqualityFromPathsCore(string relPath, string basePath, string outPath, int skip = 0, decimal precision = Tolerance)
534+
protected bool CheckEqualityFromPathsCore(string relPath, string basePath, string outPath, int skip = 0, int digitsOfPrecision = DigitsOfPrecision)
535535
{
536536
Contracts.Assert(skip >= 0);
537537

@@ -578,9 +578,7 @@ protected bool CheckEqualityFromPathsCore(string relPath, string basePath, strin
578578
}
579579

580580
count++;
581-
582-
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
583-
GetNumbersFromFile(ref line1, ref line2, precision);
581+
GetNumbersFromFile(ref line1, ref line2, digitsOfPrecision);
584582

585583
if (line1 != line2)
586584
{
@@ -594,28 +592,45 @@ protected bool CheckEqualityFromPathsCore(string relPath, string basePath, strin
594592
}
595593
}
596594

597-
private static void GetNumbersFromFile(ref string firstString, ref string secondString, decimal precision)
595+
private static void GetNumbersFromFile(ref string firstString, ref string secondString, int digitsOfPrecision)
598596
{
599597
Regex _matchNumer = new Regex(@"\b[0-9]+\.?[0-9]*\b", RegexOptions.IgnoreCase | RegexOptions.Compiled);
600598
MatchCollection firstCollection = _matchNumer.Matches(firstString);
601599
MatchCollection secondCollection = _matchNumer.Matches(secondString);
602600

603-
MatchNumberWithTolerance(firstCollection, secondCollection, precision);
601+
MatchNumberWithTolerance(firstCollection, secondCollection, digitsOfPrecision);
604602
firstString = _matchNumer.Replace(firstString, "%Number%");
605603
secondString = _matchNumer.Replace(secondString, "%Number%");
606604
}
607605

608-
private static void MatchNumberWithTolerance(MatchCollection firstCollection, MatchCollection secondCollection, decimal precision)
606+
private static void MatchNumberWithTolerance(MatchCollection firstCollection, MatchCollection secondCollection, int digitsOfPrecision)
609607
{
610608
for (int i = 0; i < firstCollection.Count; i++)
611609
{
612-
decimal f1 = decimal.Parse(firstCollection[i].ToString());
613-
decimal f2 = decimal.Parse(secondCollection[i].ToString());
610+
double f1 = double.Parse(firstCollection[i].ToString());
611+
double f2 = double.Parse(secondCollection[i].ToString());
612+
613+
double allowedVariance = Math.Pow(10, -digitsOfPrecision);
614+
double delta = Round(f1, digitsOfPrecision) - Round(f2, digitsOfPrecision);
614615

615-
Assert.InRange(f1, f2 - (f2 / precision), f2 + (f2 / precision));
616+
Assert.InRange(delta, -allowedVariance, allowedVariance);
616617
}
617618
}
618619

620+
private static double Round(double value, int digitsOfPrecision)
621+
{
622+
if ((value == 0) || double.IsInfinity(value) || double.IsNaN(value))
623+
{
624+
return value;
625+
}
626+
627+
double absValue = Math.Abs(value);
628+
double integralDigitCount = Math.Floor(Math.Log10(absValue) + 1);
629+
630+
double scale = Math.Pow(10, integralDigitCount);
631+
return scale * Math.Round(value / scale, digitsOfPrecision);
632+
}
633+
619634
#if TOLERANCE_ENABLED
620635
// This corresponds to how much relative error is tolerable for a value of 0.
621636
const Float RelativeToleranceStepSize = (Float)0.001;

‎test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ public TestImpl(RunContextBase ctx) :
9090
/// <summary>
9191
/// Run the predictor with given args and check if it adds up
9292
/// </summary>
93-
protected void Run(RunContext ctx, decimal precision = Tolerance)
93+
protected void Run(RunContext ctx, int digitsOfPrecision = DigitsOfPrecision)
9494
{
9595
Contracts.Assert(IsActive);
9696
List<string> args = new List<string>();
@@ -164,7 +164,7 @@ protected void Run(RunContext ctx, decimal precision = Tolerance)
164164
}
165165
var consOutPath = ctx.StdoutPath();
166166
TestCore(ctx, ctx.Command.ToString(), runcmd);
167-
bool matched = consOutPath.CheckEqualityNormalized(precision);
167+
bool matched = consOutPath.CheckEqualityNormalized(digitsOfPrecision);
168168

169169
if (modelPath != null && (ctx.Summary || ctx.SaveAsIni))
170170
{
@@ -190,7 +190,7 @@ protected void Run(RunContext ctx, decimal precision = Tolerance)
190190
}
191191

192192
MainForTest(Env, LogWriter, str);
193-
files.ForEach(file => CheckEqualityNormalized(dir, file, precision: precision));
193+
files.ForEach(file => CheckEqualityNormalized(dir, file, digitsOfPrecision: digitsOfPrecision));
194194
}
195195

196196
if (ctx.Command == Cmd.Train || ctx.Command == Cmd.Test || ctx.ExpectedToFail)
@@ -351,11 +351,11 @@ protected void RunAllTests(
351351
/// Run TrainTest, CV, and TrainSaveTest for a single predictor on a single dataset.
352352
/// </summary>
353353
protected void RunOneAllTests(PredictorAndArgs predictor, TestDataset dataset,
354-
string[] extraSettings = null, string extraTag = "", bool summary = false, decimal precision = Tolerance)
354+
string[] extraSettings = null, string extraTag = "", bool summary = false, int digitsOfPrecision = DigitsOfPrecision)
355355
{
356356
Contracts.Assert(IsActive);
357-
Run_TrainTest(predictor, dataset, extraSettings, extraTag, summary: summary, precision: precision);
358-
Run_CV(predictor, dataset, extraSettings, extraTag, useTest: true, precision: precision);
357+
Run_TrainTest(predictor, dataset, extraSettings, extraTag, summary: summary, digitsOfPrecision: digitsOfPrecision);
358+
Run_CV(predictor, dataset, extraSettings, extraTag, useTest: true, digitsOfPrecision: digitsOfPrecision);
359359
}
360360

361361
/// <summary>
@@ -383,10 +383,10 @@ protected RunContext Run_Train(PredictorAndArgs predictor, TestDataset dataset,
383383
/// Run a train-test unit test
384384
/// </summary>
385385
protected void Run_TrainTest(PredictorAndArgs predictor, TestDataset dataset,
386-
string[] extraSettings = null, string extraTag = "", bool expectFailure = false, bool summary = false, bool saveAsIni = false, decimal precision = Tolerance)
386+
string[] extraSettings = null, string extraTag = "", bool expectFailure = false, bool summary = false, bool saveAsIni = false, int digitsOfPrecision = DigitsOfPrecision)
387387
{
388388
RunContext ctx = new RunContext(this, Cmd.TrainTest, predictor, dataset, extraSettings, extraTag, expectFailure: expectFailure, summary: summary, saveAsIni: saveAsIni);
389-
Run(ctx, precision);
389+
Run(ctx, digitsOfPrecision);
390390
}
391391

392392
// REVIEW: Remove TrainSaveTest and supporting code.
@@ -421,7 +421,7 @@ protected void Run_Test(PredictorAndArgs predictor, TestDataset dataset, string
421421
/// <paramref name="useTest"/> is set.
422422
/// </summary>
423423
protected void Run_CV(PredictorAndArgs predictor, TestDataset dataset,
424-
string[] extraSettings = null, string extraTag = "", bool useTest = false, decimal precision = Tolerance)
424+
string[] extraSettings = null, string extraTag = "", bool useTest = false, int digitsOfPrecision = DigitsOfPrecision)
425425
{
426426
if (useTest)
427427
{
@@ -431,7 +431,7 @@ protected void Run_CV(PredictorAndArgs predictor, TestDataset dataset,
431431
dataset.trainFilename = dataset.testFilename;
432432
}
433433
RunContext cvCtx = new RunContext(this, Cmd.CV, predictor, dataset, extraSettings, extraTag);
434-
Run(cvCtx, precision);
434+
Run(cvCtx, digitsOfPrecision);
435435
}
436436

437437
/// <summary>

‎test/Microsoft.ML.TestFramework/TestCommandBase.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,10 @@ public bool CheckEquality()
7070
return _testCmd.CheckEquality(_dir, _name);
7171
}
7272

73-
public bool CheckEqualityNormalized(decimal precision = Tolerance)
73+
public bool CheckEqualityNormalized(int digitsOfPrecision = DigitsOfPrecision)
7474
{
7575
Contracts.Assert(CanBeBaselined);
76-
return _testCmd.CheckEqualityNormalized(_dir, _name, precision: precision);
76+
return _testCmd.CheckEqualityNormalized(_dir, _name, digitsOfPrecision: digitsOfPrecision);
7777
}
7878

7979
public string ArgStr(string name)

0 commit comments

Comments
 (0)
Please sign in to comment.