diff --git a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs
index 6ebfa5fc05..b2e6095d7f 100644
--- a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs
+++ b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs
@@ -161,7 +161,7 @@ public void EarlyStoppingTest()
[TestCategory("Logistic Regression")]
public void MulticlassLRTest()
{
- RunOneAllTests(TestLearners.multiclassLogisticRegression, TestDatasets.iris, precision: 10_000);
+ RunOneAllTests(TestLearners.multiclassLogisticRegression, TestDatasets.iris, digitsOfPrecision: 4);
Done();
}
@@ -173,7 +173,7 @@ public void MulticlassLRTest()
[TestCategory("Logistic Regression")]
public void MulticlassLRNonNegativeTest()
{
- RunOneAllTests(TestLearners.multiclassLogisticRegressionNonNegative, TestDatasets.iris, precision: 10_000);
+ RunOneAllTests(TestLearners.multiclassLogisticRegressionNonNegative, TestDatasets.iris, digitsOfPrecision: 4);
Done();
}
@@ -200,8 +200,8 @@ public void MulticlassTreeFeaturizedLRTest()
{
RunMTAThread(() =>
{
- RunOneAllTests(TestLearners.multiclassLogisticRegression, TestDatasets.irisTreeFeaturized, precision: 10_000);
- RunOneAllTests(TestLearners.multiclassLogisticRegression, TestDatasets.irisTreeFeaturizedPermuted, precision: 10_000);
+ RunOneAllTests(TestLearners.multiclassLogisticRegression, TestDatasets.irisTreeFeaturized, digitsOfPrecision: 4);
+ RunOneAllTests(TestLearners.multiclassLogisticRegression, TestDatasets.irisTreeFeaturizedPermuted, digitsOfPrecision: 4);
});
Done();
}
diff --git a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs
index ad077a85d3..fe05ce5b8a 100644
--- a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs
+++ b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs
@@ -23,7 +23,7 @@ namespace Microsoft.ML.Runtime.RunTests
///
public abstract partial class BaseTestBaseline : BaseTestClass
{
- public const decimal Tolerance = 10_000_000;
+ public const int DigitsOfPrecision = 7;
protected BaseTestBaseline(ITestOutputHelper output) : base(output)
{
@@ -352,12 +352,12 @@ protected bool CheckEquality(string dir, string name, string nameBase = null)
/// Check whether two files are same ignoring volatile differences (path, dates, times, etc).
/// Returns true if the check passes.
///
- protected bool CheckEqualityNormalized(string dir, string name, string nameBase = null, decimal precision = Tolerance)
+ protected bool CheckEqualityNormalized(string dir, string name, string nameBase = null, int digitsOfPrecision = DigitsOfPrecision)
{
- return CheckEqualityCore(dir, name, nameBase ?? name, true, precision);
+ return CheckEqualityCore(dir, name, nameBase ?? name, true, digitsOfPrecision);
}
- protected bool CheckEqualityCore(string dir, string name, string nameBase, bool normalize, decimal precision = Tolerance)
+ protected bool CheckEqualityCore(string dir, string name, string nameBase, bool normalize, int digitsOfPrecision = DigitsOfPrecision)
{
Contracts.Assert(IsActive);
Contracts.AssertValue(dir); // Can be empty.
@@ -384,7 +384,7 @@ protected bool CheckEqualityCore(string dir, string name, string nameBase, bool
if (!CheckBaseFile(basePath))
return false;
- bool res = CheckEqualityFromPathsCore(relPath, basePath, outPath, precision: precision);
+ bool res = CheckEqualityFromPathsCore(relPath, basePath, outPath, digitsOfPrecision: digitsOfPrecision);
// No need to keep the raw (unnormalized) output file.
if (normalize && res)
@@ -501,7 +501,7 @@ protected bool CheckOutputIsSuffix(string basePath, string outPath, int skip = 0
/// skipping the given number of lines on the output, and finding the corresponding line
/// in the baseline.
///
- protected bool CheckEqualityNormalized(string dir, string name, string suffix, int skip, decimal precision = Tolerance)
+ protected bool CheckEqualityNormalized(string dir, string name, string suffix, int skip, int digitsOfPrecision = DigitsOfPrecision)
{
Contracts.Assert(IsActive);
Contracts.AssertValue(dir); // Can be empty.
@@ -522,7 +522,7 @@ protected bool CheckEqualityNormalized(string dir, string name, string suffix, i
if (!CheckBaseFile(basePath))
return false;
- bool res = CheckEqualityFromPathsCore(relPath, basePath, outPath, skip, precision);
+ bool res = CheckEqualityFromPathsCore(relPath, basePath, outPath, skip, digitsOfPrecision);
// No need to keep the raw (unnormalized) output file.
if (res)
@@ -531,7 +531,7 @@ protected bool CheckEqualityNormalized(string dir, string name, string suffix, i
return res;
}
- protected bool CheckEqualityFromPathsCore(string relPath, string basePath, string outPath, int skip = 0, decimal precision = Tolerance)
+ protected bool CheckEqualityFromPathsCore(string relPath, string basePath, string outPath, int skip = 0, int digitsOfPrecision = DigitsOfPrecision)
{
Contracts.Assert(skip >= 0);
@@ -578,9 +578,7 @@ protected bool CheckEqualityFromPathsCore(string relPath, string basePath, strin
}
count++;
-
- if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
- GetNumbersFromFile(ref line1, ref line2, precision);
+ GetNumbersFromFile(ref line1, ref line2, digitsOfPrecision);
if (line1 != line2)
{
@@ -594,28 +592,45 @@ protected bool CheckEqualityFromPathsCore(string relPath, string basePath, strin
}
}
- private static void GetNumbersFromFile(ref string firstString, ref string secondString, decimal precision)
+ private static void GetNumbersFromFile(ref string firstString, ref string secondString, int digitsOfPrecision)
{
Regex _matchNumer = new Regex(@"\b[0-9]+\.?[0-9]*\b", RegexOptions.IgnoreCase | RegexOptions.Compiled);
MatchCollection firstCollection = _matchNumer.Matches(firstString);
MatchCollection secondCollection = _matchNumer.Matches(secondString);
- MatchNumberWithTolerance(firstCollection, secondCollection, precision);
+ MatchNumberWithTolerance(firstCollection, secondCollection, digitsOfPrecision);
firstString = _matchNumer.Replace(firstString, "%Number%");
secondString = _matchNumer.Replace(secondString, "%Number%");
}
- private static void MatchNumberWithTolerance(MatchCollection firstCollection, MatchCollection secondCollection, decimal precision)
+ private static void MatchNumberWithTolerance(MatchCollection firstCollection, MatchCollection secondCollection, int digitsOfPrecision)
{
for (int i = 0; i < firstCollection.Count; i++)
{
- decimal f1 = decimal.Parse(firstCollection[i].ToString());
- decimal f2 = decimal.Parse(secondCollection[i].ToString());
+ double f1 = double.Parse(firstCollection[i].ToString());
+ double f2 = double.Parse(secondCollection[i].ToString());
+
+ double allowedVariance = Math.Pow(10, -digitsOfPrecision);
+ double delta = Round(f1, digitsOfPrecision) - Round(f2, digitsOfPrecision);
- Assert.InRange(f1, f2 - (f2 / precision), f2 + (f2 / precision));
+ Assert.InRange(delta, -allowedVariance, allowedVariance);
}
}
+ private static double Round(double value, int digitsOfPrecision)
+ {
+ if ((value == 0) || double.IsInfinity(value) || double.IsNaN(value))
+ {
+ return value;
+ }
+
+ double absValue = Math.Abs(value);
+ double integralDigitCount = Math.Floor(Math.Log10(absValue) + 1);
+
+ double scale = Math.Pow(10, integralDigitCount);
+ return scale * Math.Round(value / scale, digitsOfPrecision);
+ }
+
#if TOLERANCE_ENABLED
// This corresponds to how much relative error is tolerable for a value of 0.
const Float RelativeToleranceStepSize = (Float)0.001;
diff --git a/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs b/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs
index 63aa38239b..39a95abd1a 100644
--- a/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs
+++ b/test/Microsoft.ML.TestFramework/BaseTestPredictorsMaml.cs
@@ -90,7 +90,7 @@ public TestImpl(RunContextBase ctx) :
///
/// Run the predictor with given args and check if it adds up
///
- protected void Run(RunContext ctx, decimal precision = Tolerance)
+ protected void Run(RunContext ctx, int digitsOfPrecision = DigitsOfPrecision)
{
Contracts.Assert(IsActive);
List args = new List();
@@ -164,7 +164,7 @@ protected void Run(RunContext ctx, decimal precision = Tolerance)
}
var consOutPath = ctx.StdoutPath();
TestCore(ctx, ctx.Command.ToString(), runcmd);
- bool matched = consOutPath.CheckEqualityNormalized(precision);
+ bool matched = consOutPath.CheckEqualityNormalized(digitsOfPrecision);
if (modelPath != null && (ctx.Summary || ctx.SaveAsIni))
{
@@ -190,7 +190,7 @@ protected void Run(RunContext ctx, decimal precision = Tolerance)
}
MainForTest(Env, LogWriter, str);
- files.ForEach(file => CheckEqualityNormalized(dir, file, precision: precision));
+ files.ForEach(file => CheckEqualityNormalized(dir, file, digitsOfPrecision: digitsOfPrecision));
}
if (ctx.Command == Cmd.Train || ctx.Command == Cmd.Test || ctx.ExpectedToFail)
@@ -351,11 +351,11 @@ protected void RunAllTests(
/// Run TrainTest, CV, and TrainSaveTest for a single predictor on a single dataset.
///
protected void RunOneAllTests(PredictorAndArgs predictor, TestDataset dataset,
- string[] extraSettings = null, string extraTag = "", bool summary = false, decimal precision = Tolerance)
+ string[] extraSettings = null, string extraTag = "", bool summary = false, int digitsOfPrecision = DigitsOfPrecision)
{
Contracts.Assert(IsActive);
- Run_TrainTest(predictor, dataset, extraSettings, extraTag, summary: summary, precision: precision);
- Run_CV(predictor, dataset, extraSettings, extraTag, useTest: true, precision: precision);
+ Run_TrainTest(predictor, dataset, extraSettings, extraTag, summary: summary, digitsOfPrecision: digitsOfPrecision);
+ Run_CV(predictor, dataset, extraSettings, extraTag, useTest: true, digitsOfPrecision: digitsOfPrecision);
}
///
@@ -383,10 +383,10 @@ protected RunContext Run_Train(PredictorAndArgs predictor, TestDataset dataset,
/// Run a train-test unit test
///
protected void Run_TrainTest(PredictorAndArgs predictor, TestDataset dataset,
- string[] extraSettings = null, string extraTag = "", bool expectFailure = false, bool summary = false, bool saveAsIni = false, decimal precision = Tolerance)
+ string[] extraSettings = null, string extraTag = "", bool expectFailure = false, bool summary = false, bool saveAsIni = false, int digitsOfPrecision = DigitsOfPrecision)
{
RunContext ctx = new RunContext(this, Cmd.TrainTest, predictor, dataset, extraSettings, extraTag, expectFailure: expectFailure, summary: summary, saveAsIni: saveAsIni);
- Run(ctx, precision);
+ Run(ctx, digitsOfPrecision);
}
// REVIEW: Remove TrainSaveTest and supporting code.
@@ -421,7 +421,7 @@ protected void Run_Test(PredictorAndArgs predictor, TestDataset dataset, string
/// is set.
///
protected void Run_CV(PredictorAndArgs predictor, TestDataset dataset,
- string[] extraSettings = null, string extraTag = "", bool useTest = false, decimal precision = Tolerance)
+ string[] extraSettings = null, string extraTag = "", bool useTest = false, int digitsOfPrecision = DigitsOfPrecision)
{
if (useTest)
{
@@ -431,7 +431,7 @@ protected void Run_CV(PredictorAndArgs predictor, TestDataset dataset,
dataset.trainFilename = dataset.testFilename;
}
RunContext cvCtx = new RunContext(this, Cmd.CV, predictor, dataset, extraSettings, extraTag);
- Run(cvCtx, precision);
+ Run(cvCtx, digitsOfPrecision);
}
///
diff --git a/test/Microsoft.ML.TestFramework/TestCommandBase.cs b/test/Microsoft.ML.TestFramework/TestCommandBase.cs
index ae539f9992..ad5ae609a4 100644
--- a/test/Microsoft.ML.TestFramework/TestCommandBase.cs
+++ b/test/Microsoft.ML.TestFramework/TestCommandBase.cs
@@ -70,10 +70,10 @@ public bool CheckEquality()
return _testCmd.CheckEquality(_dir, _name);
}
- public bool CheckEqualityNormalized(decimal precision = Tolerance)
+ public bool CheckEqualityNormalized(int digitsOfPrecision = DigitsOfPrecision)
{
Contracts.Assert(CanBeBaselined);
- return _testCmd.CheckEqualityNormalized(_dir, _name, precision: precision);
+ return _testCmd.CheckEqualityNormalized(_dir, _name, digitsOfPrecision: digitsOfPrecision);
}
public string ArgStr(string name)