-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Modify API for advanced settings (several learners) #2163
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
961fe9d
91939d9
fed254d
c447d96
7d15dfb
5cf6c19
c89741f
5b316f2
ea9a5de
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ | |
using Microsoft.ML.Training; | ||
using Newtonsoft.Json.Linq; | ||
|
||
[assembly: LoadableClass(typeof(MulticlassLogisticRegression), typeof(MulticlassLogisticRegression.Arguments), | ||
[assembly: LoadableClass(typeof(MulticlassLogisticRegression), typeof(MulticlassLogisticRegression.Options), | ||
new[] { typeof(SignatureMultiClassClassifierTrainer), typeof(SignatureTrainer) }, | ||
MulticlassLogisticRegression.UserNameValue, | ||
MulticlassLogisticRegression.LoadNameValue, | ||
|
@@ -38,14 +38,14 @@ namespace Microsoft.ML.Learners | |
{ | ||
/// <include file = 'doc.xml' path='doc/members/member[@name="LBFGS"]/*' /> | ||
/// <include file = 'doc.xml' path='docs/members/example[@name="LogisticRegressionClassifier"]/*' /> | ||
public sealed class MulticlassLogisticRegression : LbfgsTrainerBase<MulticlassLogisticRegression.Arguments, | ||
public sealed class MulticlassLogisticRegression : LbfgsTrainerBase<MulticlassLogisticRegression.Options, | ||
MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>, MulticlassLogisticRegressionModelParameters> | ||
{ | ||
public const string LoadNameValue = "MultiClassLogisticRegression"; | ||
internal const string UserNameValue = "Multi-class Logistic Regression"; | ||
internal const string ShortName = "mlr"; | ||
|
||
public sealed class Arguments : ArgumentsBase | ||
public sealed class Options : ArgumentsBase | ||
{ | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "Show statistics of training examples.", ShortName = "stat", SortOrder = 50)] | ||
public bool ShowTrainingStats = false; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
xml #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
@@ -82,19 +82,16 @@ public sealed class Arguments : ArgumentsBase | |
/// <param name="l2Weight">Weight of L2 regularizer term.</param> | ||
/// <param name="memorySize">Memory size for <see cref="LogisticRegression"/>. Low=faster, less accurate.</param> | ||
/// <param name="optimizationTolerance">Threshold for optimizer convergence.</param> | ||
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param> | ||
public MulticlassLogisticRegression(IHostEnvironment env, | ||
internal MulticlassLogisticRegression(IHostEnvironment env, | ||
string labelColumn = DefaultColumnNames.Label, | ||
string featureColumn = DefaultColumnNames.Features, | ||
string weights = null, | ||
float l1Weight = Arguments.Defaults.L1Weight, | ||
float l2Weight = Arguments.Defaults.L2Weight, | ||
float optimizationTolerance = Arguments.Defaults.OptTol, | ||
int memorySize = Arguments.Defaults.MemorySize, | ||
bool enforceNoNegativity = Arguments.Defaults.EnforceNonNegativity, | ||
Action<Arguments> advancedSettings = null) | ||
: base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), weights, advancedSettings, | ||
l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity) | ||
float l1Weight = Options.Defaults.L1Weight, | ||
float l2Weight = Options.Defaults.L2Weight, | ||
float optimizationTolerance = Options.Defaults.OptTol, | ||
int memorySize = Options.Defaults.MemorySize, | ||
bool enforceNoNegativity = Options.Defaults.EnforceNonNegativity) | ||
: base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), weights, l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity) | ||
{ | ||
Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); | ||
Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); | ||
|
@@ -105,8 +102,8 @@ public MulticlassLogisticRegression(IHostEnvironment env, | |
/// <summary> | ||
/// Initializes a new instance of <see cref="MulticlassLogisticRegression"/> | ||
/// </summary> | ||
internal MulticlassLogisticRegression(IHostEnvironment env, Arguments args) | ||
: base(env, args, TrainerUtils.MakeU4ScalarColumn(args.LabelColumn)) | ||
internal MulticlassLogisticRegression(IHostEnvironment env, Options options) | ||
: base(env, options, TrainerUtils.MakeU4ScalarColumn(options.LabelColumn)) | ||
{ | ||
ShowTrainingStats = Args.ShowTrainingStats; | ||
} | ||
|
@@ -1007,14 +1004,14 @@ public partial class LogisticRegression | |
ShortName = MulticlassLogisticRegression.ShortName, | ||
XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/Standard/LogisticRegression/doc.xml' path='doc/members/member[@name=""LBFGS""]/*' />", | ||
@"<include file='../Microsoft.ML.StandardLearners/Standard/LogisticRegression/doc.xml' path='doc/members/example[@name=""LogisticRegressionClassifier""]/*' />" })] | ||
public static CommonOutputs.MulticlassClassificationOutput TrainMultiClass(IHostEnvironment env, MulticlassLogisticRegression.Arguments input) | ||
public static CommonOutputs.MulticlassClassificationOutput TrainMultiClass(IHostEnvironment env, MulticlassLogisticRegression.Options input) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
var host = env.Register("TrainLRMultiClass"); | ||
host.CheckValue(input, nameof(input)); | ||
EntryPointUtils.CheckInputArgs(host, input); | ||
|
||
return LearnerEntryPointsUtils.Train<MulticlassLogisticRegression.Arguments, CommonOutputs.MulticlassClassificationOutput>(host, input, | ||
return LearnerEntryPointsUtils.Train<MulticlassLogisticRegression.Options, CommonOutputs.MulticlassClassificationOutput>(host, input, | ||
() => new MulticlassLogisticRegression(host, input), | ||
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), | ||
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn)); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,7 @@ | |
using Microsoft.ML.Trainers.Online; | ||
using Microsoft.ML.Training; | ||
|
||
[assembly: LoadableClass(AveragedPerceptronTrainer.Summary, typeof(AveragedPerceptronTrainer), typeof(AveragedPerceptronTrainer.Arguments), | ||
[assembly: LoadableClass(AveragedPerceptronTrainer.Summary, typeof(AveragedPerceptronTrainer), typeof(AveragedPerceptronTrainer.Options), | ||
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, | ||
AveragedPerceptronTrainer.UserNameValue, | ||
AveragedPerceptronTrainer.LoadNameValue, "avgper", AveragedPerceptronTrainer.ShortName)] | ||
|
@@ -37,9 +37,9 @@ public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer<BinaryPred | |
internal const string ShortName = "ap"; | ||
internal const string Summary = "Averaged Perceptron Binary Classifier."; | ||
|
||
private readonly Arguments _args; | ||
private readonly Options _args; | ||
|
||
public sealed class Arguments : AveragedLinearArguments | ||
public sealed class Options : AveragedLinearArguments | ||
{ | ||
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)] | ||
public ISupportClassificationLossFactory LossFunction = new HingeLoss.Arguments(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
xml #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
@@ -83,10 +83,10 @@ public override LinearBinaryModelParameters CreatePredictor() | |
} | ||
} | ||
|
||
internal AveragedPerceptronTrainer(IHostEnvironment env, Arguments args) | ||
: base(args, env, UserNameValue, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn)) | ||
internal AveragedPerceptronTrainer(IHostEnvironment env, Options options) | ||
: base(options, env, UserNameValue, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn)) | ||
{ | ||
_args = args; | ||
_args = options; | ||
LossFunction = _args.LossFunction.CreateComponent(env); | ||
} | ||
|
||
|
@@ -103,18 +103,16 @@ internal AveragedPerceptronTrainer(IHostEnvironment env, Arguments args) | |
/// <param name="decreaseLearningRate">Wheather to decrease learning rate as iterations progress.</param> | ||
/// <param name="l2RegularizerWeight">L2 Regularization Weight.</param> | ||
/// <param name="numIterations">The number of training iteraitons.</param> | ||
/// <param name="advancedSettings">A delegate to supply more advanced arguments to the algorithm.</param> | ||
public AveragedPerceptronTrainer(IHostEnvironment env, | ||
internal AveragedPerceptronTrainer(IHostEnvironment env, | ||
string labelColumn = DefaultColumnNames.Label, | ||
string featureColumn = DefaultColumnNames.Features, | ||
string weights = null, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. delete.. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do see it being used to initialize In reply to: 248544047 [](ancestors = 248544047) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
IClassificationLoss lossFunction = null, | ||
float learningRate = Arguments.AveragedDefaultArgs.LearningRate, | ||
bool decreaseLearningRate = Arguments.AveragedDefaultArgs.DecreaseLearningRate, | ||
float l2RegularizerWeight = Arguments.AveragedDefaultArgs.L2RegularizerWeight, | ||
int numIterations = Arguments.AveragedDefaultArgs.NumIterations, | ||
Action<Arguments> advancedSettings = null) | ||
: this(env, InvokeAdvanced(advancedSettings, new Arguments | ||
float learningRate = Options.AveragedDefaultArgs.LearningRate, | ||
bool decreaseLearningRate = Options.AveragedDefaultArgs.DecreaseLearningRate, | ||
float l2RegularizerWeight = Options.AveragedDefaultArgs.L2RegularizerWeight, | ||
int numIterations = Options.AveragedDefaultArgs.NumIterations) | ||
: this(env, new Options | ||
{ | ||
LabelColumn = labelColumn, | ||
FeatureColumn = featureColumn, | ||
|
@@ -124,7 +122,7 @@ public AveragedPerceptronTrainer(IHostEnvironment env, | |
L2RegularizerWeight = l2RegularizerWeight, | ||
NumIterations = numIterations, | ||
LossFunction = new TrivialFactory(lossFunction ?? new HingeLoss()) | ||
})) | ||
}) | ||
{ | ||
} | ||
|
||
|
@@ -191,14 +189,14 @@ public BinaryPredictionTransformer<LinearBinaryModelParameters> Train(IDataView | |
ShortName = ShortName, | ||
XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/Standard/Online/doc.xml' path='doc/members/member[@name=""AP""]/*' />", | ||
@"<include file='../Microsoft.ML.StandardLearners/Standard/Online/doc.xml' path='doc/members/example[@name=""AP""]/*' />"})] | ||
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Arguments input) | ||
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
var host = env.Register("TrainAP"); | ||
host.CheckValue(input, nameof(input)); | ||
EntryPointUtils.CheckInputArgs(host, input); | ||
|
||
return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.BinaryClassificationOutput>(host, input, | ||
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input, | ||
() => new AveragedPerceptronTrainer(host, input), | ||
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn), | ||
calibrator: input.Calibrator, maxCalibrationExamples: input.MaxCalibrationExamples); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i'd just delete it.. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do see it being passed to base constructor
LbfgsTrainerBase
where its used to initializeWeightsColumn
Thoughts ?In reply to: 248543741 [](ancestors = 248543741)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The ctor. Sorry for the bad selection.
In reply to: 248769580 [](ancestors = 248769580,248543741)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Deleting the constructor u mean ? That would be #2100
In reply to: 248834382 [](ancestors = 248834382,248769580,248543741)