diff --git a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs index 6ebfa5fc05..b2e6095d7f 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs @@ -161,7 +161,7 @@ public void EarlyStoppingTest() [TestCategory("Logistic Regression")] public void MulticlassLRTest() { - RunOneAllTests(TestLearners.multiclassLogisticRegression, TestDatasets.iris, precision: 10_000); + RunOneAllTests(TestLearners.multiclassLogisticRegression, TestDatasets.iris, digitsOfPrecision: 4); Done(); } @@ -173,7 +173,7 @@ public void MulticlassLRTest() [TestCategory("Logistic Regression")] public void MulticlassLRNonNegativeTest() { - RunOneAllTests(TestLearners.multiclassLogisticRegressionNonNegative, TestDatasets.iris, precision: 10_000); + RunOneAllTests(TestLearners.multiclassLogisticRegressionNonNegative, TestDatasets.iris, digitsOfPrecision: 4); Done(); } @@ -200,8 +200,8 @@ public void MulticlassTreeFeaturizedLRTest() { RunMTAThread(() => { - RunOneAllTests(TestLearners.multiclassLogisticRegression, TestDatasets.irisTreeFeaturized, precision: 10_000); - RunOneAllTests(TestLearners.multiclassLogisticRegression, TestDatasets.irisTreeFeaturizedPermuted, precision: 10_000); + RunOneAllTests(TestLearners.multiclassLogisticRegression, TestDatasets.irisTreeFeaturized, digitsOfPrecision: 4); + RunOneAllTests(TestLearners.multiclassLogisticRegression, TestDatasets.irisTreeFeaturizedPermuted, digitsOfPrecision: 4); }); Done(); } diff --git a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs index ad077a85d3..fe05ce5b8a 100644 --- a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs +++ b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs @@ -23,7 +23,7 @@ namespace Microsoft.ML.Runtime.RunTests /// public abstract partial class BaseTestBaseline : BaseTestClass { - public const decimal Tolerance = 10_000_000; + public const int DigitsOfPrecision = 7; protected BaseTestBaseline(ITestOutputHelper output) : base(output) { @@ -352,12 +352,12 @@ protected bool CheckEquality(string dir, string name, string nameBase = null) /// Check whether two files are same ignoring volatile differences (path, dates, times, etc). /// Returns true if the check passes. /// - protected bool CheckEqualityNormalized(string dir, string name, string nameBase = null, decimal precision = Tolerance) + protected bool CheckEqualityNormalized(string dir, string name, string nameBase = null, int digitsOfPrecision = DigitsOfPrecision) { - return CheckEqualityCore(dir, name, nameBase ?? name, true, precision); + return CheckEqualityCore(dir, name, nameBase ?? name, true, digitsOfPrecision); } - protected bool CheckEqualityCore(string dir, string name, string nameBase, bool normalize, decimal precision = Tolerance) + protected bool CheckEqualityCore(string dir, string name, string nameBase, bool normalize, int digitsOfPrecision = DigitsOfPrecision) { Contracts.Assert(IsActive); Contracts.AssertValue(dir); // Can be empty. @@ -384,7 +384,7 @@ protected bool CheckEqualityCore(string dir, string name, string nameBase, bool if (!CheckBaseFile(basePath)) return false; - bool res = CheckEqualityFromPathsCore(relPath, basePath, outPath, precision: precision); + bool res = CheckEqualityFromPathsCore(relPath, basePath, outPath, digitsOfPrecision: digitsOfPrecision); // No need to keep the raw (unnormalized) output file. if (normalize && res) @@ -501,7 +501,7 @@ protected bool CheckOutputIsSuffix(string basePath, string outPath, int skip = 0 /// skipping the given number of lines on the output, and finding the corresponding line /// in the baseline. /// - protected bool CheckEqualityNormalized(string dir, string name, string suffix, int skip, decimal precision = Tolerance) + protected bool CheckEqualityNormalized(string dir, string name, string suffix, int skip, int digitsOfPrecision = DigitsOfPrecision) { Contracts.Assert(IsActive); Contracts.AssertValue(dir); // Can be empty. @@ -522,7 +522,7 @@ protected bool CheckEqualityNormalized(string dir, string name, string suffix, i if (!CheckBaseFile(basePath)) return false; - bool res = CheckEqualityFromPathsCore(relPath, basePath, outPath, skip, precision); + bool res = CheckEqualityFromPathsCore(relPath, basePath, outPath, skip, digitsOfPrecision); // No need to keep the raw (unnormalized) output file. if (res) @@ -531,7 +531,7 @@ protected bool CheckEqualityNormalized(string dir, string name, string suffix, i return res; } - protected bool CheckEqualityFromPathsCore(string relPath, string basePath, string outPath, int skip = 0, decimal precision = Tolerance) + protected bool CheckEqualityFromPathsCore(string relPath, string basePath, string outPath, int skip = 0, int digitsOfPrecision = DigitsOfPrecision) { Contracts.Assert(skip >= 0); @@ -578,9 +578,7 @@ protected bool CheckEqualityFromPathsCore(string relPath, string basePath, strin } count++; - - if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - GetNumbersFromFile(ref line1, ref line2, precision); + GetNumbersFromFile(ref line1, ref line2, digitsOfPrecision); if (line1 != line2) { @@ -594,28 +592,45 @@ protected bool CheckEqualityFromPathsCore(string relPath, string basePath, strin } } - private static void GetNumbersFromFile(ref string firstString, ref string secondString, decimal precision) + private static void GetNumbersFromFile(ref string firstString, ref string secondString, int digitsOfPrecision) { Regex _matchNumer = new Regex(@"\b[0-9]+\.?[0-9]*\b", RegexOptions.IgnoreCase | RegexOptions.Compiled); MatchCollection firstCollection = _matchNumer.Matches(firstString); MatchCollection secondCollection = _matchNumer.Matches(secondString); - MatchNumberWithTolerance(firstCollection, secondCollection, precision); + MatchNumberWithTolerance(firstCollection, secondCollection, digitsOfPrecision); firstString = _matchNumer.Replace(firstString, "%Number%"); secondString = _matchNumer.Replace(secondString, "%Number%"); } - private static void MatchNumberWithTolerance(MatchCollection firstCollection, MatchCollection secondCollection, decimal precision) + private static void MatchNumberWithTolerance(MatchCollection firstCollection, MatchCollection secondCollection, int digitsOfPrecision) { for (int i = 0; i < firstCollection.Count; i++) { - decimal f1 = decimal.Parse(firstCollection[i].ToString()); - decimal f2 = decimal.Parse(secondCollection[i].ToString()); + double f1 = double.Parse(firstCollection[i].ToString()); + double f2 = double.Parse(secondCollection[i].ToString()); + + double allowedVariance = Math.Pow(10, -digitsOfPrecision); + double delta = Round(f1, digitsOfPrecision) - Round(f2, digitsOfPrecision); - Assert.InRange(f1, f2 - (f2 / precision), f2 + (f2 / precision)); + Assert.InRange(delta, -allowedVariance, allowedVariance); } } + private static double Round(double value, int digitsOfPrecision) + { + if ((value == 0) || double.IsInfinity(value) || double.IsNaN(value)) + { + return value; + } + + double absValue = Math.Abs(value); + double integralDigitCount = Math.Floor(Math.Log10(absValue) + 1); + + double scale = Math.Pow(10, integralDigitCount); + return scale * Math.Round(value / scale, digitsOfPrecision); + } + #if TOLERANCE_ENABLED // This corresponds to how much relative error is tolerable for a value of 0. const Float RelativeToleranceStepSize = (Float)0.001; diff --git a/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs b/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs index 63aa38239b..39a95abd1a 100644 --- a/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs +++ b/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs @@ -90,7 +90,7 @@ public TestImpl(RunContextBase ctx) : /// /// Run the predictor with given args and check if it adds up /// - protected void Run(RunContext ctx, decimal precision = Tolerance) + protected void Run(RunContext ctx, int digitsOfPrecision = DigitsOfPrecision) { Contracts.Assert(IsActive); List args = new List(); @@ -164,7 +164,7 @@ protected void Run(RunContext ctx, decimal precision = Tolerance) } var consOutPath = ctx.StdoutPath(); TestCore(ctx, ctx.Command.ToString(), runcmd); - bool matched = consOutPath.CheckEqualityNormalized(precision); + bool matched = consOutPath.CheckEqualityNormalized(digitsOfPrecision); if (modelPath != null && (ctx.Summary || ctx.SaveAsIni)) { @@ -190,7 +190,7 @@ protected void Run(RunContext ctx, decimal precision = Tolerance) } MainForTest(Env, LogWriter, str); - files.ForEach(file => CheckEqualityNormalized(dir, file, precision: precision)); + files.ForEach(file => CheckEqualityNormalized(dir, file, digitsOfPrecision: digitsOfPrecision)); } if (ctx.Command == Cmd.Train || ctx.Command == Cmd.Test || ctx.ExpectedToFail) @@ -351,11 +351,11 @@ protected void RunAllTests( /// Run TrainTest, CV, and TrainSaveTest for a single predictor on a single dataset. /// protected void RunOneAllTests(PredictorAndArgs predictor, TestDataset dataset, - string[] extraSettings = null, string extraTag = "", bool summary = false, decimal precision = Tolerance) + string[] extraSettings = null, string extraTag = "", bool summary = false, int digitsOfPrecision = DigitsOfPrecision) { Contracts.Assert(IsActive); - Run_TrainTest(predictor, dataset, extraSettings, extraTag, summary: summary, precision: precision); - Run_CV(predictor, dataset, extraSettings, extraTag, useTest: true, precision: precision); + Run_TrainTest(predictor, dataset, extraSettings, extraTag, summary: summary, digitsOfPrecision: digitsOfPrecision); + Run_CV(predictor, dataset, extraSettings, extraTag, useTest: true, digitsOfPrecision: digitsOfPrecision); } /// @@ -383,10 +383,10 @@ protected RunContext Run_Train(PredictorAndArgs predictor, TestDataset dataset, /// Run a train-test unit test /// protected void Run_TrainTest(PredictorAndArgs predictor, TestDataset dataset, - string[] extraSettings = null, string extraTag = "", bool expectFailure = false, bool summary = false, bool saveAsIni = false, decimal precision = Tolerance) + string[] extraSettings = null, string extraTag = "", bool expectFailure = false, bool summary = false, bool saveAsIni = false, int digitsOfPrecision = DigitsOfPrecision) { RunContext ctx = new RunContext(this, Cmd.TrainTest, predictor, dataset, extraSettings, extraTag, expectFailure: expectFailure, summary: summary, saveAsIni: saveAsIni); - Run(ctx, precision); + Run(ctx, digitsOfPrecision); } // REVIEW: Remove TrainSaveTest and supporting code. @@ -421,7 +421,7 @@ protected void Run_Test(PredictorAndArgs predictor, TestDataset dataset, string /// is set. /// protected void Run_CV(PredictorAndArgs predictor, TestDataset dataset, - string[] extraSettings = null, string extraTag = "", bool useTest = false, decimal precision = Tolerance) + string[] extraSettings = null, string extraTag = "", bool useTest = false, int digitsOfPrecision = DigitsOfPrecision) { if (useTest) { @@ -431,7 +431,7 @@ protected void Run_CV(PredictorAndArgs predictor, TestDataset dataset, dataset.trainFilename = dataset.testFilename; } RunContext cvCtx = new RunContext(this, Cmd.CV, predictor, dataset, extraSettings, extraTag); - Run(cvCtx, precision); + Run(cvCtx, digitsOfPrecision); } /// diff --git a/test/Microsoft.ML.TestFramework/TestCommandBase.cs b/test/Microsoft.ML.TestFramework/TestCommandBase.cs index ae539f9992..ad5ae609a4 100644 --- a/test/Microsoft.ML.TestFramework/TestCommandBase.cs +++ b/test/Microsoft.ML.TestFramework/TestCommandBase.cs @@ -70,10 +70,10 @@ public bool CheckEquality() return _testCmd.CheckEquality(_dir, _name); } - public bool CheckEqualityNormalized(decimal precision = Tolerance) + public bool CheckEqualityNormalized(int digitsOfPrecision = DigitsOfPrecision) { Contracts.Assert(CanBeBaselined); - return _testCmd.CheckEqualityNormalized(_dir, _name, precision: precision); + return _testCmd.CheckEqualityNormalized(_dir, _name, digitsOfPrecision: digitsOfPrecision); } public string ArgStr(string name)