-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Conversion of Parallel Stochastic Gradient Descent (SymSGD) to estimator #1012
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…still need to pass the weight column
public sealed class SymSgdClassificationTrainer : | ||
TrainerBase<TPredictor>, | ||
ITrainer<TPredictor> | ||
public sealed class SymSgdClassificationTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<TPredictor>, TPredictor> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TrainerEstimatorBase [](start = 54, length = 20)
SingleFeatureTrainerEstimatorBase #Closed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I can't find that. what is the name of the base class I
should derive from?
In reply to: 220056284 [](ancestors = 220056284)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My bad, this is fine. got it confused with the TransformerPredictors
In reply to: 220256891 [](ancestors = 220256891,220256710,220056284)
public sealed class SymSgdClassificationTrainer : | ||
TrainerBase<TPredictor>, | ||
ITrainer<TPredictor> | ||
public sealed class SymSgdClassificationTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<TPredictor>, TPredictor> | ||
{ | ||
public const string LoadNameValue = "SymbolicSGD"; | ||
public const string UserNameValue = "Symbolic SGD (binary)"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
public [](start = 8, length = 6)
internal, all 4 if you can. #Closed
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn), | ||
TrainerUtils.MakeR4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) | ||
{ | ||
var args = new Arguments(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
var args [](start = 12, length = 8)
you can use _args directly. #Closed
/// </summary> | ||
internal SymSgdClassificationTrainer(IHostEnvironment env, Arguments args) | ||
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), | ||
TrainerUtils.MakeR4ScalarLabel(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(null)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MakeR4ScalarLabel [](start = 31, length = 17)
bool #Closed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You need a scalar bool label for binary classification.
In reply to: 220057042 [](ancestors = 220057042)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the catch, I had only modified it above and forgot this instance.
In reply to: 220299186 [](ancestors = 220299186,220057042)
new TextLoader.Column("Features", DataKind.R4, new [] { new TextLoader.Range(1, 9) } ) | ||
} | ||
}).Read(new MultiFileSource(GetDataPath(TestDatasets.breastCancer.trainFilename))); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a GetBinaryPipeline() methof on TrainerEstimators. See if it helps reduce duplication. #Closed
// Apply the advanced args, if the user supplied any. | ||
_args.Check(Host); | ||
advancedSettings?.Invoke(_args); | ||
_args.FeatureColumn = featureColumn; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
args.FeatureColumn [](start = 13, length = 18)
weights column #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Fact] | ||
public void TestEstimatorSymSgdClassificationTrainer() | ||
{ | ||
(IEstimator<ITransformer> pipe, IDataView dataView) = GetBinaryClassificationPipeline(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IEstimator [](start = 13, length = 24)
For things like this don't be afraid of var
. #Closed
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param> | ||
/// <param name="labelColumn">The name of the label column.</param> | ||
/// <param name="featureColumn">The name of the feature column.</param> | ||
/// <param name="weightColumn">The name for the column containing the initial weight.</param> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
param name="weightColumn">The name for the column containing the initial weight. [](start = 13, length = 88)
If we don't support weighted training this should be removed, at least in its description of a column.
I think what was actually intended, because it says "initial weight," was that when you train, an initial predictor from which you get initial weights should be supported. (e.g., as we see here:)
var initPred = context.InitialPredictor; |
However this would be a parameter to the training/fitting method, not the constructor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is, I think the request was to have something akin to this auxiliary method here:
machinelearning/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
Line 1495 in 893a385
public BinaryPredictionTransformer<TScalarPredictor> Train(IDataView trainData, IDataView validationData = null, IPredictor initialPredictor = null) => TrainTransformer(trainData, validationData, initialPredictor); |
In reply to: 220340515 [](ancestors = 220340515)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is actually already taken care of.
TrainerEstimatorBase has TrainTransformer method, which takes an initial predictor, and passes that to TrainModelCore through the TrainContext.
In reply to: 220342250 [](ancestors = 220342250,220340515)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're talking about this I guess.
machinelearning/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
Lines 127 to 128 in d7b062d
protected TTransformer TrainTransformer(IDataView trainSet, | |
IDataView validationSet = null, IPredictor initPredictor = null) |
@artidoro, no, as you see this method is protected. It is not accessible to people that instantiate the class. So, no, not taken care of.
Your PR was already merged, so whoops, too late I guess. Maybe we can file an issue to fix this post 0.6.
In reply to: 220345387 [](ancestors = 220345387,220342250,220340515)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After discussing with Pete, we should keep this as an advanced setting. The default value is 0 right now. In reply to: 424204987 [](ancestors = 424204987) Refers to: src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs:63 in 6697257. [](commit_id = 6697257, deletion_comment = False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ongoing work on converting the trainers to estimators. This PR converts the Parallel Stochastic Gradient Descent classification trainer (SymSGD).