-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Better names to calibreated linear classification models #3034
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ba22e12
60d7d9c
d28c5ce
901e706
4302c88
38b9940
32a5e60
b177db1
804c4a8
3aadd3e
ff16bee
7451306
3dfa9cf
7750434
8af0d21
7b7cf8d
7c8e882
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,20 +19,20 @@ | |
using Microsoft.ML.Trainers; | ||
using Microsoft.ML.Transforms; | ||
|
||
[assembly: LoadableClass(typeof(SymbolicSgdTrainer), typeof(SymbolicSgdTrainer.Options), | ||
[assembly: LoadableClass(typeof(SymbolicSgdLogisticRegressionBinaryTrainer), typeof(SymbolicSgdLogisticRegressionBinaryTrainer.Options), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @rogancarr..could you cross-check ? To me the Class names do seem to match up fine, with both having the word
In reply to: 268859710 [](ancestors = 268859710) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah. Our conclusion was having In reply to: 268865148 [](ancestors = 268865148,268859710) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got it. I was mixing up catalogs & classes. #Resolved |
||
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) }, | ||
SymbolicSgdTrainer.UserNameValue, | ||
SymbolicSgdTrainer.LoadNameValue, | ||
SymbolicSgdTrainer.ShortName)] | ||
SymbolicSgdLogisticRegressionBinaryTrainer.UserNameValue, | ||
SymbolicSgdLogisticRegressionBinaryTrainer.LoadNameValue, | ||
SymbolicSgdLogisticRegressionBinaryTrainer.ShortName)] | ||
|
||
[assembly: LoadableClass(typeof(void), typeof(SymbolicSgdTrainer), null, typeof(SignatureEntryPointModule), SymbolicSgdTrainer.LoadNameValue)] | ||
[assembly: LoadableClass(typeof(void), typeof(SymbolicSgdLogisticRegressionBinaryTrainer), null, typeof(SignatureEntryPointModule), SymbolicSgdLogisticRegressionBinaryTrainer.LoadNameValue)] | ||
|
||
namespace Microsoft.ML.Trainers | ||
{ | ||
using TPredictor = CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>; | ||
|
||
/// <include file='doc.xml' path='doc/members/member[@name="SymSGD"]/*' /> | ||
public sealed class SymbolicSgdTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<TPredictor>, TPredictor> | ||
public sealed class SymbolicSgdLogisticRegressionBinaryTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<TPredictor>, TPredictor> | ||
{ | ||
internal const string LoadNameValue = "SymbolicSGD"; | ||
internal const string UserNameValue = "Symbolic SGD (binary)"; | ||
|
@@ -195,9 +195,9 @@ private protected override TPredictor TrainModelCore(TrainContext context) | |
private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification; | ||
|
||
/// <summary> | ||
/// Initializes a new instance of <see cref="SymbolicSgdTrainer"/> | ||
/// Initializes a new instance of <see cref="SymbolicSgdLogisticRegressionBinaryTrainer"/> | ||
/// </summary> | ||
internal SymbolicSgdTrainer(IHostEnvironment env, Options options) | ||
internal SymbolicSgdLogisticRegressionBinaryTrainer(IHostEnvironment env, Options options) | ||
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(options.FeatureColumnName), | ||
TrainerUtils.MakeBoolScalarLabel(options.LabelColumnName)) | ||
{ | ||
|
@@ -223,7 +223,7 @@ private protected override BinaryPredictionTransformer<TPredictor> MakeTransform | |
=> new BinaryPredictionTransformer<TPredictor>(Host, model, trainSchema, FeatureColumn.Name); | ||
|
||
/// <summary> | ||
/// Continues the training of <see cref="SymbolicSgdTrainer"/> using an already trained <paramref name="modelParameters"/> | ||
/// Continues the training of <see cref="SymbolicSgdLogisticRegressionBinaryTrainer"/> using an already trained <paramref name="modelParameters"/> | ||
/// a <see cref="BinaryPredictionTransformer"/>. | ||
/// </summary> | ||
public BinaryPredictionTransformer<TPredictor> Fit(IDataView trainData, LinearModelParameters modelParameters) | ||
|
@@ -241,8 +241,8 @@ private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape | |
|
||
[TlcModule.EntryPoint(Name = "Trainers.SymSgdBinaryClassifier", | ||
Desc = "Train a symbolic SGD.", | ||
UserName = SymbolicSgdTrainer.UserNameValue, | ||
ShortName = SymbolicSgdTrainer.ShortName)] | ||
UserName = SymbolicSgdLogisticRegressionBinaryTrainer.UserNameValue, | ||
ShortName = SymbolicSgdLogisticRegressionBinaryTrainer.ShortName)] | ||
internal static CommonOutputs.BinaryClassificationOutput TrainSymSgd(IHostEnvironment env, Options options) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
|
@@ -251,7 +251,7 @@ internal static CommonOutputs.BinaryClassificationOutput TrainSymSgd(IHostEnviro | |
EntryPointUtils.CheckInputArgs(host, options); | ||
|
||
return TrainerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, options, | ||
() => new SymbolicSgdTrainer(host, options), | ||
() => new SymbolicSgdLogisticRegressionBinaryTrainer(host, options), | ||
() => TrainerEntryPointsUtils.FindColumn(host, options.TrainingData.Schema, options.LabelColumnName)); | ||
} | ||
|
||
|
@@ -324,7 +324,7 @@ public void Free() | |
// giving an array, we are at _storage[_storageIndex][_indexInCurArray]. | ||
private int _indexInCurArray; | ||
// This is used to access AccelMemBudget, AccelChunkSize and UsedMemory | ||
private readonly SymbolicSgdTrainer _trainer; | ||
private readonly SymbolicSgdLogisticRegressionBinaryTrainer _trainer; | ||
|
||
private readonly IChannel _ch; | ||
|
||
|
@@ -336,7 +336,7 @@ public void Free() | |
/// </summary> | ||
/// <param name="trainer"></param> | ||
/// <param name="ch"></param> | ||
public ArrayManager(SymbolicSgdTrainer trainer, IChannel ch) | ||
public ArrayManager(SymbolicSgdLogisticRegressionBinaryTrainer trainer, IChannel ch) | ||
{ | ||
_storage = new List<VeryLongArray>(); | ||
// Setting the default value to 2^17. | ||
|
@@ -500,7 +500,7 @@ private sealed class InputDataManager : IDisposable | |
// This is the index to go over the instances in instanceProperties | ||
private int _instanceIndex; | ||
// This is used to access AccelMemBudget, AccelChunkSize and UsedMemory | ||
private readonly SymbolicSgdTrainer _trainer; | ||
private readonly SymbolicSgdLogisticRegressionBinaryTrainer _trainer; | ||
private readonly IChannel _ch; | ||
|
||
// Whether memorySize was big enough to load the entire instances into the buffer | ||
|
@@ -511,7 +511,7 @@ private sealed class InputDataManager : IDisposable | |
// Tells if we have gone through the dataset entirely. | ||
public bool FinishedTheLoad => !_cursorMoveNext; | ||
|
||
public InputDataManager(SymbolicSgdTrainer trainer, FloatLabelCursor.Factory cursorFactory, IChannel ch) | ||
public InputDataManager(SymbolicSgdLogisticRegressionBinaryTrainer trainer, FloatLabelCursor.Factory cursorFactory, IChannel ch) | ||
{ | ||
_instIndices = new ArrayManager<int>(trainer, ch); | ||
_instValues = new ArrayManager<float>(trainer, ch); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How do you feel about making the other calibrated linear trainers, like SDCA into
XyzLogisticRegression()
. #ResolvedThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can always add that at a later time. #Resolved