-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Adding training statistics for LR in the HAL learners package. #1392
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
2d5fbf6
68dd4a6
e8fede2
f649287
06f9704
96127d0
3831f5d
c638cbd
dd9524e
e540d63
2752b60
fe29307
f0b4707
5386a8c
fb897ed
377b462
89301b4
46316ea
c8d060a
737d173
39ca55e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,87 @@ | ||||||||
// Licensed to the .NET Foundation under one or more agreements. | ||||||||
// The .NET Foundation licenses this file to you under the MIT license. | ||||||||
// See the LICENSE file in the project root for more information. | ||||||||
|
||||||||
using Microsoft.ML.Runtime.Data; | ||||||||
using Microsoft.ML.Runtime.Internal.Utilities; | ||||||||
using Microsoft.ML.Trainers.HalLearners; | ||||||||
using System; | ||||||||
|
||||||||
namespace Microsoft.ML.Runtime.Learners | ||||||||
{ | ||||||||
using Mkl = OlsLinearRegressionTrainer.Mkl; | ||||||||
|
||||||||
/// <include file='doc.xml' path='doc/members/member[@name="LBFGS"]/*' /> | ||||||||
/// <include file='doc.xml' path='docs/members/example[@name="LogisticRegressionBinaryClassifier"]/*' /> | ||||||||
public static class LogisticRegressionTrainingStats | ||||||||
{ | ||||||||
|
||||||||
public static void ComputeExtendedTrainingStatistics(this LinearBinaryPredictor model, IChannel ch, float l2Weight = LogisticRegression.Arguments.Defaults.L2Weight) | ||||||||
{ | ||||||||
Contracts.AssertValue(ch); | ||||||||
Contracts.AssertValue(model.Statistics, $"Training Statistics can get generated after training finishes. Train with setting: ShowTrainigStats set to true."); | ||||||||
Contracts.Assert(l2Weight > 0); | ||||||||
|
||||||||
int numSelectedParams = model.Statistics.ParametersCount; | ||||||||
|
||||||||
// Apply Cholesky Decomposition to find the inverse of the Hessian. | ||||||||
Double[] invHessian = null; | ||||||||
try | ||||||||
{ | ||||||||
// First, find the Cholesky decomposition LL' of the Hessian. | ||||||||
Mkl.Pptrf(Mkl.Layout.RowMajor, Mkl.UpLo.Lo, numSelectedParams, model.Statistics.Hessian); | ||||||||
// Note that hessian is already modified at this point. It is no longer the original Hessian, | ||||||||
// but instead represents the Cholesky decomposition L. | ||||||||
// Also note that the following routine is supposed to consume the Cholesky decomposition L instead | ||||||||
// of the original information matrix. | ||||||||
Mkl.Pptri(Mkl.Layout.RowMajor, Mkl.UpLo.Lo, numSelectedParams, model.Statistics.Hessian); | ||||||||
// At this point, hessian should contain the inverse of the original Hessian matrix. | ||||||||
// Swap hessian with invHessian to avoid confusion in the following context. | ||||||||
Utils.Swap(ref model.Statistics.Hessian, ref invHessian); | ||||||||
Contracts.Assert(model.Statistics.Hessian == null); | ||||||||
} | ||||||||
catch (DllNotFoundException) | ||||||||
{ | ||||||||
throw ch.ExceptNotSupp("The MKL library (MklImports.dll) or one of its dependencies is missing."); | ||||||||
} | ||||||||
|
||||||||
float[] stdErrorValues = new float[numSelectedParams]; | ||||||||
stdErrorValues[0] = (float)Math.Sqrt(invHessian[0]); | ||||||||
|
||||||||
for (int i = 1; i < numSelectedParams; i++) | ||||||||
{ | ||||||||
// Initialize with inverse Hessian. | ||||||||
stdErrorValues[i] = (Single)invHessian[i * (i + 1) / 2 + i]; | ||||||||
} | ||||||||
|
||||||||
if (l2Weight > 0) | ||||||||
{ | ||||||||
// Iterate through all entries of inverse Hessian to make adjustment to variance. | ||||||||
// A discussion on ridge regularized LR coefficient covariance matrix can be found here: | ||||||||
// http://www.ncbi.nlm.nih.gov/pmc/articles/PMC3228544/ | ||||||||
|
// http://www.ncbi.nlm.nih.gov/pmc/articles/PMC3228544/ | |
// http://www.aloki.hu/pdf/0402_171179.pdf (Equations 11 and 25) | |
``` #Resolved |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// http://www.inf.unibz.it/dis/teaching/DWDM/project2010/LogisticRegression.pdf | |
// http://www.inf.unibz.it/dis/teaching/DWDM/project2010/LogisticRegression.pdf (Section "Significance testing in ridge logistic regression") | |
``` #Resolved |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line doesn't make a lot of sense to me. The code first compute \sigma=sqrt(x+\lambda)
and them do something like \sigma + \lambda * \sigma^2
. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -92,14 +92,14 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight | |
[Argument(ArgumentType.AtMostOnce, HelpText = "Enforce non-negative weights", ShortName = "nn", SortOrder = 90)] | ||
public bool EnforceNonNegativity = Defaults.EnforceNonNegativity; | ||
|
||
internal static class Defaults | ||
public static class Defaults | ||
|
||
{ | ||
internal const float L2Weight = 1; | ||
internal const float L1Weight = 1; | ||
internal const float OptTol = 1e-7f; | ||
internal const int MemorySize = 20; | ||
internal const int MaxIterations = int.MaxValue; | ||
internal const bool EnforceNonNegativity = false; | ||
public const float L2Weight = 1; | ||
public const float L1Weight = 1; | ||
public const float OptTol = 1e-7f; | ||
public const int MemorySize = 20; | ||
public const int MaxIterations = int.MaxValue; | ||
public const bool EnforceNonNegativity = false; | ||
} | ||
} | ||
|
||
|
@@ -258,7 +258,7 @@ private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColu | |
} | ||
|
||
protected virtual int ClassCount => 1; | ||
protected int BiasCount => ClassCount; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
revert #Resolved |
||
public int BiasCount => ClassCount; | ||
protected int WeightCount => ClassCount * NumFeatures; | ||
protected virtual Optimizer InitializeOptimizer(IChannel ch, FloatLabelCursor.Factory cursorFactory, | ||
out VBuffer<float> init, out ITerminationCriterion terminationCriterion) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
|
||
using Microsoft.ML.Core.Data; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Internal.Calibration; | ||
using Microsoft.ML.Runtime.Learners; | ||
using Microsoft.ML.Trainers; | ||
using Xunit; | ||
|
@@ -38,5 +39,33 @@ public void TestEstimatorPoissonRegression() | |
TestEstimatorCore(pipe, dataView); | ||
Done(); | ||
} | ||
|
||
[Fact] | ||
public void TestLogisticRegressionStats() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
combine both tests together #WontFix There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently running into issues with this. Will investigate and log a bug. In reply to: 231416451 [](ancestors = 231416451) |
||
{ | ||
(IEstimator<ITransformer> pipe, IDataView dataView) = GetBinaryClassificationPipeline(); | ||
|
||
pipe = pipe.Append(new LogisticRegression(Env, "Features", "Label", advancedSettings: s => { s.ShowTrainingStats = true; })); | ||
var transformerChain = pipe.Fit(dataView) as TransformerChain<BinaryPredictionTransformer<ParameterMixingCalibratedPredictor>>; | ||
|
||
var linearModel = transformerChain.LastTransformer.Model.SubPredictor as LinearBinaryPredictor; | ||
var stats = linearModel.Statistics; | ||
|
||
LinearModelStatistics.TryGetBiasStatistics(stats, 2, out float stdError, out float zScore, out float pValue); | ||
|
||
Assert.Equal(0.0f, stdError); | ||
Assert.Equal(0.0f, zScore); | ||
Assert.Equal(0.0f, pValue); | ||
|
||
using (var ch = Env.Start("Calcuating STD for LR.")) | ||
linearModel.ComputeExtendedTrainingStatistics(ch); | ||
|
||
LinearModelStatistics.TryGetBiasStatistics(stats, 2, out stdError, out zScore, out pValue); | ||
|
||
Assert.True(stdError > 0); | ||
|
||
Assert.True(zScore > 0); | ||
|
||
Done(); | ||
} | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function name can be just
compute(...)
because it's insideLogisticRegressionTrainingStats
. #Resolved