Skip to content

Commit 24926ee

Browse files
authored
Scrubbing LogisticRegression learners (#2761)
* LogisticRegression * LogisticRegression - 2 * added ShortName for entrypoint back compatibility; updated manifest files * also looked into MultiClass LR and PoissonRegression * review comments * review comments * review comments * ensure MLcontext names are same as those used in Options
1 parent 5746ec9 commit 24926ee

File tree

19 files changed

+305
-246
lines changed

19 files changed

+305
-246
lines changed

src/Microsoft.ML.Data/Prediction/Calibrator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ public abstract class CalibratedModelParametersBase<TSubModel, TCalibrator> :
173173
where TSubModel : class
174174
where TCalibrator : class, ICalibrator
175175
{
176-
protected readonly IHost Host;
176+
private protected readonly IHost Host;
177177

178178
// Strongly-typed members.
179179
/// <summary>

src/Microsoft.ML.Mkl.Components/ComputeLRTrainingStdThroughHal.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace Microsoft.ML.Trainers
1111
{
1212
using MklOls = OrdinaryLeastSquaresRegressionTrainer.Mkl;
1313

14-
public sealed class ComputeLRTrainingStdThroughMkl : ComputeLRTrainingStd
14+
public sealed class ComputeLRTrainingStdThroughMkl : ComputeLogisticRegressionStandardDeviation
1515
{
1616
/// <summary>
1717
/// Computes the standard deviation matrix of each of the non-zero training weights, needed to further calculate the standard deviation,
@@ -23,7 +23,7 @@ public sealed class ComputeLRTrainingStdThroughMkl : ComputeLRTrainingStd
2323
/// <param name="currentWeightsCount"></param>
2424
/// <param name="ch">The <see cref="IChannel"/> used for messaging.</param>
2525
/// <param name="l2Weight">The L2Weight used for training. (Supply the same one that got used during training.)</param>
26-
public override VBuffer<float> ComputeStd(double[] hessian, int[] weightIndices, int numSelectedParams, int currentWeightsCount, IChannel ch, float l2Weight)
26+
public override VBuffer<float> ComputeStandardDeviation(double[] hessian, int[] weightIndices, int numSelectedParams, int currentWeightsCount, IChannel ch, float l2Weight)
2727
{
2828
Contracts.AssertValue(ch);
2929
Contracts.AssertValue(hessian, nameof(hessian));

src/Microsoft.ML.StandardLearners/Standard/LinearModelParameters.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ private protected virtual float Score(in VBuffer<float> src)
261261
return Bias + VectorUtils.DotProduct(in _weightsDense, in src);
262262
}
263263

264-
protected virtual void GetFeatureContributions(in VBuffer<float> features, ref VBuffer<float> contributions, int top, int bottom, bool normalize)
264+
private protected virtual void GetFeatureContributions(in VBuffer<float> features, ref VBuffer<float> contributions, int top, int bottom, bool normalize)
265265
{
266266
if (features.Length != Weight.Length)
267267
throw Contracts.Except("Input is of length {0} does not match expected length of weights {1}", features.Length, Weight.Length);
@@ -662,6 +662,9 @@ IList<KeyValuePair<string, object>> ICanGetSummaryInKeyValuePairs.GetSummaryInKe
662662
}
663663
}
664664

665+
/// <summary>
666+
/// The model parameters class for Poisson Regression.
667+
/// </summary>
665668
public sealed class PoissonRegressionModelParameters : RegressionModelParameters, IParameterMixer<float>
666669
{
667670
internal const string LoaderSignature = "PoissonRegressionExec";

src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs

Lines changed: 77 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -22,36 +22,54 @@ public abstract class LbfgsTrainerBase<TOptions, TTransformer, TModel> : Trainer
2222
{
2323
public abstract class OptionsBase : TrainerInputBaseWithWeight
2424
{
25-
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularization weight", ShortName = "l2", SortOrder = 50)]
25+
/// <summary>
26+
/// L2 regularization weight.
27+
/// </summary>
28+
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularization weight", ShortName = "l2, L2Weight", SortOrder = 50)]
2629
[TGUI(Label = "L2 Weight", Description = "Weight of L2 regularizer term", SuggestedSweeps = "0,0.1,1")]
2730
[TlcModule.SweepableFloatParamAttribute(0.0f, 1.0f, numSteps: 4)]
28-
public float L2Weight = Defaults.L2Weight;
31+
public float L2Regularization = Defaults.L2Regularization;
2932

30-
[Argument(ArgumentType.AtMostOnce, HelpText = "L1 regularization weight", ShortName = "l1", SortOrder = 50)]
33+
/// <summary>
34+
/// L1 regularization weight.
35+
/// </summary>
36+
[Argument(ArgumentType.AtMostOnce, HelpText = "L1 regularization weight", ShortName = "l1, L1Weight", SortOrder = 50)]
3137
[TGUI(Label = "L1 Weight", Description = "Weight of L1 regularizer term", SuggestedSweeps = "0,0.1,1")]
3238
[TlcModule.SweepableFloatParamAttribute(0.0f, 1.0f, numSteps: 4)]
33-
public float L1Weight = Defaults.L1Weight;
39+
public float L1Regularization = Defaults.L1Regularization;
3440

41+
/// <summary>
42+
/// Tolerance parameter for optimization convergence. (Low = slower, more accurate).
43+
/// </summary>
3544
[Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance parameter for optimization convergence. Low = slower, more accurate",
36-
ShortName = "ot", SortOrder = 50)]
45+
ShortName = "ot, OptTol", SortOrder = 50)]
3746
[TGUI(Label = "Optimization Tolerance", Description = "Threshold for optimizer convergence", SuggestedSweeps = "1e-4,1e-7")]
3847
[TlcModule.SweepableDiscreteParamAttribute(new object[] { 1e-4f, 1e-7f })]
39-
public float OptTol = Defaults.OptTol;
48+
public float OptmizationTolerance = Defaults.OptimizationTolerance;
4049

50+
/// <summary>
51+
/// Number of previous iterations to remember for estimate of Hessian.
52+
/// </summary>
4153
[Argument(ArgumentType.AtMostOnce, HelpText = "Memory size for L-BFGS. Low=faster, less accurate",
42-
ShortName = "m", SortOrder = 50)]
54+
ShortName = "m, MemorySize", SortOrder = 50)]
4355
[TGUI(Description = "Memory size for L-BFGS", SuggestedSweeps = "5,20,50")]
4456
[TlcModule.SweepableDiscreteParamAttribute("MemorySize", new object[] { 5, 20, 50 })]
45-
public int MemorySize = Defaults.MemorySize;
57+
public int IterationsToRemember = Defaults.IterationsToRemember;
4658

47-
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum iterations.", ShortName = "maxiter")]
59+
/// <summary>
60+
/// Number of iterations.
61+
/// </summary>
62+
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum iterations.", ShortName = "maxiter, MaxIterations")]
4863
[TGUI(Label = "Max Number of Iterations")]
4964
[TlcModule.SweepableLongParamAttribute("MaxIterations", 1, int.MaxValue)]
50-
public int MaxIterations = Defaults.MaxIterations;
65+
public int NumberOfIterations = Defaults.NumberOfIterations;
5166

67+
/// <summary>
68+
/// Run SGD to initialize LR weights, converging to this tolerance.
69+
/// </summary>
5270
[Argument(ArgumentType.AtMostOnce, HelpText = "Run SGD to initialize LR weights, converging to this tolerance",
53-
ShortName = "sgd")]
54-
public float SgdInitializationTolerance = 0;
71+
ShortName = "sgd, SgdInitializationTolerance")]
72+
public float StochasticGradientDescentInitilaizationTolerance = 0;
5573

5674
/// <summary>
5775
/// Features must occur in at least this many instances to be included
@@ -68,37 +86,43 @@ public abstract class OptionsBase : TrainerInputBaseWithWeight
6886
/// <summary>
6987
/// Init Weights Diameter
7088
/// </summary>
71-
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Init weights diameter", ShortName = "initwts", SortOrder = 140)]
89+
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Init weights diameter", ShortName = "initwts, InitWtsDiameter", SortOrder = 140)]
7290
[TGUI(Label = "Initial Weights Scale", SuggestedSweeps = "0,0.1,0.5,1")]
7391
[TlcModule.SweepableFloatParamAttribute("InitWtsDiameter", 0.0f, 1.0f, numSteps: 5)]
74-
public float InitWtsDiameter = 0;
92+
public float InitialWeightsDiameter = 0;
7593

7694
// Deprecated
7795
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether or not to use threads. Default is true",
7896
ShortName = "t", Hide = true)]
79-
public bool UseThreads = true;
97+
internal bool UseThreads = true;
8098

8199
/// <summary>
82100
/// Number of threads. Null means use the number of processors.
83101
/// </summary>
84-
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of threads", ShortName = "nt")]
85-
public int? NumThreads;
102+
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of threads", ShortName = "nt, NumThreads")]
103+
public int? NumberOfThreads;
86104

105+
/// <summary>
106+
/// Force densification of the internal optimization vectors. Default is false.
107+
/// </summary>
87108
[Argument(ArgumentType.AtMostOnce, HelpText = "Force densification of the internal optimization vectors", ShortName = "do")]
88109
[TlcModule.SweepableDiscreteParamAttribute("DenseOptimizer", new object[] { false, true })]
89110
public bool DenseOptimizer = false;
90111

112+
/// <summary>
113+
/// Enforce non-negative weights. Default is false.
114+
/// </summary>
91115
[Argument(ArgumentType.AtMostOnce, HelpText = "Enforce non-negative weights", ShortName = "nn", SortOrder = 90)]
92116
public bool EnforceNonNegativity = Defaults.EnforceNonNegativity;
93117

94118
[BestFriend]
95119
internal static class Defaults
96120
{
97-
public const float L2Weight = 1;
98-
public const float L1Weight = 1;
99-
public const float OptTol = 1e-7f;
100-
public const int MemorySize = 20;
101-
public const int MaxIterations = int.MaxValue;
121+
public const float L2Regularization = 1;
122+
public const float L1Regularization = 1;
123+
public const float OptimizationTolerance = 1e-7f;
124+
public const int IterationsToRemember = 20;
125+
public const int NumberOfIterations = int.MaxValue;
102126
public const bool EnforceNonNegativity = false;
103127
}
104128
}
@@ -165,10 +189,10 @@ internal LbfgsTrainerBase(IHostEnvironment env,
165189
FeatureColumnName = featureColumn,
166190
LabelColumnName = labelColumn.Name,
167191
ExampleWeightColumnName = weightColumn,
168-
L1Weight = l1Weight,
169-
L2Weight = l2Weight,
170-
OptTol = optimizationTolerance,
171-
MemorySize = memorySize,
192+
L1Regularization = l1Weight,
193+
L2Regularization = l2Weight,
194+
OptmizationTolerance = optimizationTolerance,
195+
IterationsToRemember = memorySize,
172196
EnforceNonNegativity = enforceNoNegativity
173197
},
174198
labelColumn)
@@ -191,31 +215,31 @@ internal LbfgsTrainerBase(IHostEnvironment env,
191215
options.FeatureColumnName = FeatureColumn.Name;
192216
options.LabelColumnName = LabelColumn.Name;
193217
options.ExampleWeightColumnName = WeightColumn.Name;
194-
Host.CheckUserArg(!LbfgsTrainerOptions.UseThreads || LbfgsTrainerOptions.NumThreads > 0 || LbfgsTrainerOptions.NumThreads == null,
195-
nameof(LbfgsTrainerOptions.NumThreads), "numThreads must be positive (or empty for default)");
196-
Host.CheckUserArg(LbfgsTrainerOptions.L2Weight >= 0, nameof(LbfgsTrainerOptions.L2Weight), "Must be non-negative");
197-
Host.CheckUserArg(LbfgsTrainerOptions.L1Weight >= 0, nameof(LbfgsTrainerOptions.L1Weight), "Must be non-negative");
198-
Host.CheckUserArg(LbfgsTrainerOptions.OptTol > 0, nameof(LbfgsTrainerOptions.OptTol), "Must be positive");
199-
Host.CheckUserArg(LbfgsTrainerOptions.MemorySize > 0, nameof(LbfgsTrainerOptions.MemorySize), "Must be positive");
200-
Host.CheckUserArg(LbfgsTrainerOptions.MaxIterations > 0, nameof(LbfgsTrainerOptions.MaxIterations), "Must be positive");
201-
Host.CheckUserArg(LbfgsTrainerOptions.SgdInitializationTolerance >= 0, nameof(LbfgsTrainerOptions.SgdInitializationTolerance), "Must be non-negative");
202-
Host.CheckUserArg(LbfgsTrainerOptions.NumThreads == null || LbfgsTrainerOptions.NumThreads.Value >= 0, nameof(LbfgsTrainerOptions.NumThreads), "Must be non-negative");
203-
204-
Host.CheckParam(!(LbfgsTrainerOptions.L2Weight < 0), nameof(LbfgsTrainerOptions.L2Weight), "Must be non-negative, if provided.");
205-
Host.CheckParam(!(LbfgsTrainerOptions.L1Weight < 0), nameof(LbfgsTrainerOptions.L1Weight), "Must be non-negative, if provided");
206-
Host.CheckParam(!(LbfgsTrainerOptions.OptTol <= 0), nameof(LbfgsTrainerOptions.OptTol), "Must be positive, if provided.");
207-
Host.CheckParam(!(LbfgsTrainerOptions.MemorySize <= 0), nameof(LbfgsTrainerOptions.MemorySize), "Must be positive, if provided.");
208-
209-
L2Weight = LbfgsTrainerOptions.L2Weight;
210-
L1Weight = LbfgsTrainerOptions.L1Weight;
211-
OptTol = LbfgsTrainerOptions.OptTol;
212-
MemorySize =LbfgsTrainerOptions.MemorySize;
213-
MaxIterations = LbfgsTrainerOptions.MaxIterations;
214-
SgdInitializationTolerance = LbfgsTrainerOptions.SgdInitializationTolerance;
218+
Host.CheckUserArg(!LbfgsTrainerOptions.UseThreads || LbfgsTrainerOptions.NumberOfThreads > 0 || LbfgsTrainerOptions.NumberOfThreads == null,
219+
nameof(LbfgsTrainerOptions.NumberOfThreads), "Must be positive (or empty for default)");
220+
Host.CheckUserArg(LbfgsTrainerOptions.L2Regularization >= 0, nameof(LbfgsTrainerOptions.L2Regularization), "Must be non-negative");
221+
Host.CheckUserArg(LbfgsTrainerOptions.L1Regularization >= 0, nameof(LbfgsTrainerOptions.L1Regularization), "Must be non-negative");
222+
Host.CheckUserArg(LbfgsTrainerOptions.OptmizationTolerance > 0, nameof(LbfgsTrainerOptions.OptmizationTolerance), "Must be positive");
223+
Host.CheckUserArg(LbfgsTrainerOptions.IterationsToRemember > 0, nameof(LbfgsTrainerOptions.IterationsToRemember), "Must be positive");
224+
Host.CheckUserArg(LbfgsTrainerOptions.NumberOfIterations > 0, nameof(LbfgsTrainerOptions.NumberOfIterations), "Must be positive");
225+
Host.CheckUserArg(LbfgsTrainerOptions.StochasticGradientDescentInitilaizationTolerance >= 0, nameof(LbfgsTrainerOptions.StochasticGradientDescentInitilaizationTolerance), "Must be non-negative");
226+
Host.CheckUserArg(LbfgsTrainerOptions.NumberOfThreads == null || LbfgsTrainerOptions.NumberOfThreads.Value >= 0, nameof(LbfgsTrainerOptions.NumberOfThreads), "Must be non-negative");
227+
228+
Host.CheckParam(!(LbfgsTrainerOptions.L2Regularization < 0), nameof(LbfgsTrainerOptions.L2Regularization), "Must be non-negative, if provided.");
229+
Host.CheckParam(!(LbfgsTrainerOptions.L1Regularization < 0), nameof(LbfgsTrainerOptions.L1Regularization), "Must be non-negative, if provided");
230+
Host.CheckParam(!(LbfgsTrainerOptions.OptmizationTolerance <= 0), nameof(LbfgsTrainerOptions.OptmizationTolerance), "Must be positive, if provided.");
231+
Host.CheckParam(!(LbfgsTrainerOptions.IterationsToRemember <= 0), nameof(LbfgsTrainerOptions.IterationsToRemember), "Must be positive, if provided.");
232+
233+
L2Weight = LbfgsTrainerOptions.L2Regularization;
234+
L1Weight = LbfgsTrainerOptions.L1Regularization;
235+
OptTol = LbfgsTrainerOptions.OptmizationTolerance;
236+
MemorySize =LbfgsTrainerOptions.IterationsToRemember;
237+
MaxIterations = LbfgsTrainerOptions.NumberOfIterations;
238+
SgdInitializationTolerance = LbfgsTrainerOptions.StochasticGradientDescentInitilaizationTolerance;
215239
Quiet = LbfgsTrainerOptions.Quiet;
216-
InitWtsDiameter = LbfgsTrainerOptions.InitWtsDiameter;
240+
InitWtsDiameter = LbfgsTrainerOptions.InitialWeightsDiameter;
217241
UseThreads = LbfgsTrainerOptions.UseThreads;
218-
NumThreads = LbfgsTrainerOptions.NumThreads;
242+
NumThreads = LbfgsTrainerOptions.NumberOfThreads;
219243
DenseOptimizer = LbfgsTrainerOptions.DenseOptimizer;
220244
EnforceNonNegativity = LbfgsTrainerOptions.EnforceNonNegativity;
221245

@@ -245,10 +269,10 @@ private static TOptions ArgsInit(string featureColumn, SchemaShape.Column labelC
245269
FeatureColumnName = featureColumn,
246270
LabelColumnName = labelColumn.Name,
247271
ExampleWeightColumnName = weightColumn,
248-
L1Weight = l1Weight,
249-
L2Weight = l2Weight,
250-
OptTol = optimizationTolerance,
251-
MemorySize = memorySize,
272+
L1Regularization = l1Weight,
273+
L2Regularization = l2Weight,
274+
OptmizationTolerance = optimizationTolerance,
275+
IterationsToRemember = memorySize,
252276
EnforceNonNegativity = enforceNoNegativity
253277
};
254278

0 commit comments

Comments
 (0)