Skip to content

Scrubbing OneVsAll #2743

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.ML.EntryPoints/ModelOperations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ public static PredictorModelOutput CombineOvaModels(IHostEnvironment env, Combin
return new PredictorModelOutput
{
PredictorModel = new PredictorModelImpl(env, data, input.TrainingData,
OvaModelParameters.Create(host, input.UseProbabilities,
OneVersusAllModelParameters.Create(host, input.UseProbabilities,
input.ModelArray.Select(p => p.Predictor as IPredictorProducing<float>).ToArray()))
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ public static (Vector<float> score, Key<uint, TVal> predictedLabel)
int? minDataPerLeaf = null,
double? learningRate = null,
int numBoostRound = Options.Defaults.NumBoostRound,
Action<OvaModelParameters> onFit = null)
Action<OneVersusAllModelParameters> onFit = null)
{
CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, onFit);

Expand Down Expand Up @@ -343,7 +343,7 @@ public static (Vector<float> score, Key<uint, TVal> predictedLabel)
Vector<float> features,
Scalar<float> weights,
Options options,
Action<OvaModelParameters> onFit = null)
Action<OneVersusAllModelParameters> onFit = null)
{
CheckUserValues(label, features, weights, options, onFit);

Expand Down
16 changes: 8 additions & 8 deletions src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace Microsoft.ML.LightGBM
{

/// <include file='doc.xml' path='doc/members/member[@name="LightGBM"]/*' />
public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase<VBuffer<float>, MulticlassPredictionTransformer<OvaModelParameters>, OvaModelParameters>
public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase<VBuffer<float>, MulticlassPredictionTransformer<OneVersusAllModelParameters>, OneVersusAllModelParameters>
{
internal const string Summary = "LightGBM Multi Class Classifier";
internal const string LoadNameValue = "LightGBMMulticlass";
Expand Down Expand Up @@ -80,7 +80,7 @@ private LightGbmBinaryModelParameters CreateBinaryPredictor(int classID, string
return new LightGbmBinaryModelParameters(Host, GetBinaryEnsemble(classID), FeatureCount, innerArgs);
}

private protected override OvaModelParameters CreatePredictor()
private protected override OneVersusAllModelParameters CreatePredictor()
{
Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete.");

Expand All @@ -97,9 +97,9 @@ private protected override OvaModelParameters CreatePredictor()
}
string obj = (string)GetGbmParameters()["objective"];
if (obj == "multiclass")
return OvaModelParameters.Create(Host, OvaModelParameters.OutputFormula.Softmax, predictors);
return OneVersusAllModelParameters.Create(Host, OneVersusAllModelParameters.OutputFormula.Softmax, predictors);
else
return OvaModelParameters.Create(Host, predictors);
return OneVersusAllModelParameters.Create(Host, predictors);
}

private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
Expand Down Expand Up @@ -218,14 +218,14 @@ private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape
};
}

private protected override MulticlassPredictionTransformer<OvaModelParameters> MakeTransformer(OvaModelParameters model, DataViewSchema trainSchema)
=> new MulticlassPredictionTransformer<OvaModelParameters>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
private protected override MulticlassPredictionTransformer<OneVersusAllModelParameters> MakeTransformer(OneVersusAllModelParameters model, DataViewSchema trainSchema)
=> new MulticlassPredictionTransformer<OneVersusAllModelParameters>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);

/// <summary>
/// Trains a <see cref="LightGbmMulticlassTrainer"/> using both training and validation data, returns
/// a <see cref="MulticlassPredictionTransformer{OvaModelParameters}"/>.
/// a <see cref="MulticlassPredictionTransformer{OneVsAllModelParameters}"/>.
/// </summary>
public MulticlassPredictionTransformer<OvaModelParameters> Fit(IDataView trainData, IDataView validationData)
public MulticlassPredictionTransformer<OneVersusAllModelParameters> Fit(IDataView trainData, IDataView validationData)
=> TrainTransformer(trainData, validationData);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public abstract class MetaMulticlassTrainer<TTransformer, TModel> : ITrainerEsti
where TTransformer : ISingleFeaturePredictionTransformer<TModel>
where TModel : class
{
public abstract class OptionsBase
internal abstract class OptionsBase
Copy link
Member

@wschin wschin Feb 28, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we internalize an Options of a trainer? #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is internal because the only class that derives from it, Options for pkpd and ova are both internal.

The option classes are internal because there are no advanced options and all of the arguments are exposed in the argument-based constructor.


In reply to: 261015420 [](ancestors = 261015420)

{
[Argument(ArgumentType.Multiple, HelpText = "Base predictor", ShortName = "p", SortOrder = 4, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
[TGUI(Label = "Predictor Type", Description = "Type of underlying binary predictor")]
Expand All @@ -39,7 +39,7 @@ public abstract class OptionsBase
/// <summary>
/// The label column that the trainer expects.
/// </summary>
public readonly SchemaShape.Column LabelColumn;
private protected readonly SchemaShape.Column LabelColumn;

private protected readonly OptionsBase Args;
private protected readonly IHost Host;
Expand Down
Loading