Skip to content

Fixing ModelParameter discrepancies #2968

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 19, 2019
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ private class BinaryOutputRow
private readonly static Action<ContinuousInputRow, BinaryOutputRow> GreaterThanAverage = (input, output)
=> output.AboveAverage = input.MedianHomeValue > 22.6;

public static float[] GetLinearModelWeights(OrdinaryLeastSquaresRegressionModelParameters linearModel)
public static float[] GetLinearModelWeights(OlsModelParameters linearModel)
{
return linearModel.Weights.ToArray();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public static void Example()
// we could do so by tweaking the 'advancedSetting'.
var advancedPipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features")
.Append(mlContext.BinaryClassification.Trainers.SdcaCalibrated(
new SdcaCalibratedBinaryClassificationTrainer.Options {
new SdcaCalibratedBinaryTrainer.Options {
LabelColumnName = "Sentiment",
FeatureColumnName = "Features",
ConvergenceTolerance = 0.01f, // The learning rate for adjusting bias from being regularized
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public static void Example()
var trainTestData = mlContext.Data.TrainTestSplit(data, testFraction: 0.1);

// Define the trainer options.
var options = new SdcaCalibratedBinaryClassificationTrainer.Options()
var options = new SdcaCalibratedBinaryTrainer.Options()
{
// Make the convergence tolerance tighter.
ConvergenceTolerance = 0.05f,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public static void Example()
// CC 1.216908,1.248052,1.391902,0.4326252,1.099942,0.9262842,1.334019,1.08762,0.9468155,0.4811099
// DD 0.7871246,1.053327,0.8971719,1.588544,1.242697,1.362964,0.6303943,0.9810045,0.9431419,1.557455

var options = new SdcaMulticlassClassificationTrainer.Options
var options = new SdcaMulticlassTrainer.Options
{
// Add custom loss
LossFunction = new HingeLoss(),
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.FastTree/FastTreeArguments.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.FastTree;

[assembly: EntryPointModule(typeof(FastTreeBinaryClassificationTrainer.Options))]
[assembly: EntryPointModule(typeof(FastTreeBinaryTrainer.Options))]
[assembly: EntryPointModule(typeof(FastTreeRegressionTrainer.Options))]
[assembly: EntryPointModule(typeof(FastTreeTweedieTrainer.Options))]
[assembly: EntryPointModule(typeof(FastTreeRankingTrainer.Options))]
Expand Down Expand Up @@ -52,10 +52,10 @@ public enum EarlyStoppingRankingMetric
}

// XML docs are provided in the other part of this partial class. No need to duplicate the content here.
public sealed partial class FastTreeBinaryClassificationTrainer
public sealed partial class FastTreeBinaryTrainer
{
/// <summary>
/// Options for the <see cref="FastTreeBinaryClassificationTrainer"/>.
/// Options for the <see cref="FastTreeBinaryTrainer"/>.
/// </summary>
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
Expand Down Expand Up @@ -102,7 +102,7 @@ public Options()
EarlyStoppingMetric = EarlyStoppingMetric.L1Norm;
}

ITrainer IComponentFactory<ITrainer>.CreateComponent(IHostEnvironment env) => new FastTreeBinaryClassificationTrainer(env, this);
ITrainer IComponentFactory<ITrainer>.CreateComponent(IHostEnvironment env) => new FastTreeBinaryTrainer(env, this);
}
}

Expand Down
34 changes: 17 additions & 17 deletions src/Microsoft.ML.FastTree/FastTreeClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.FastTree;

[assembly: LoadableClass(FastTreeBinaryClassificationTrainer.Summary, typeof(FastTreeBinaryClassificationTrainer), typeof(FastTreeBinaryClassificationTrainer.Options),
[assembly: LoadableClass(FastTreeBinaryTrainer.Summary, typeof(FastTreeBinaryTrainer), typeof(FastTreeBinaryTrainer.Options),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) },
FastTreeBinaryClassificationTrainer.UserNameValue,
FastTreeBinaryClassificationTrainer.LoadNameValue,
FastTreeBinaryTrainer.UserNameValue,
FastTreeBinaryTrainer.LoadNameValue,
"FastTreeClassification",
"FastTree",
"ft",
FastTreeBinaryClassificationTrainer.ShortName,
FastTreeBinaryTrainer.ShortName,

// FastRank names
"FastRankBinaryClassification",
Expand Down Expand Up @@ -101,8 +101,8 @@ private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoad
/// The <see cref="IEstimator{TTransformer}"/> for training a decision tree binary classification model using FastTree.
/// </summary>
/// <include file='doc.xml' path='doc/members/member[@name="FastTree_remarks"]/*' />
public sealed partial class FastTreeBinaryClassificationTrainer :
BoostingFastTreeTrainerBase<FastTreeBinaryClassificationTrainer.Options,
public sealed partial class FastTreeBinaryTrainer :
BoostingFastTreeTrainerBase<FastTreeBinaryTrainer.Options,
BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>>,
CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>>
{
Expand All @@ -118,7 +118,7 @@ public sealed partial class FastTreeBinaryClassificationTrainer :
private double _sigmoidParameter;

/// <summary>
/// Initializes a new instance of <see cref="FastTreeBinaryClassificationTrainer"/>
/// Initializes a new instance of <see cref="FastTreeBinaryTrainer"/>
/// </summary>
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="labelColumnName">The name of the label column.</param>
Expand All @@ -128,7 +128,7 @@ public sealed partial class FastTreeBinaryClassificationTrainer :
/// <param name="minimumExampleCountPerLeaf">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
/// <param name="numberOfLeaves">The max number of leaves in each regression tree.</param>
/// <param name="numberOfTrees">Total number of decision trees to create in the ensemble.</param>
internal FastTreeBinaryClassificationTrainer(IHostEnvironment env,
internal FastTreeBinaryTrainer(IHostEnvironment env,
string labelColumnName = DefaultColumnNames.Label,
string featureColumnName = DefaultColumnNames.Features,
string exampleWeightColumnName = null,
Expand All @@ -143,11 +143,11 @@ internal FastTreeBinaryClassificationTrainer(IHostEnvironment env,
}

/// <summary>
/// Initializes a new instance of <see cref="FastTreeBinaryClassificationTrainer"/> by using the <see cref="Options"/> class.
/// Initializes a new instance of <see cref="FastTreeBinaryTrainer"/> by using the <see cref="Options"/> class.
/// </summary>
/// <param name="env">The instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="options">Algorithm advanced settings.</param>
internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, Options options)
internal FastTreeBinaryTrainer(IHostEnvironment env, Options options)
: base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumnName))
{
// Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss
Expand Down Expand Up @@ -278,7 +278,7 @@ private protected override BinaryPredictionTransformer<CalibratedModelParameters
=> new BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>>(Host, model, trainSchema, FeatureColumn.Name);

/// <summary>
/// Trains a <see cref="FastTreeBinaryClassificationTrainer"/> using both training and validation data, returns
/// Trains a <see cref="FastTreeBinaryTrainer"/> using both training and validation data, returns
/// a <see cref="BinaryPredictionTransformer{CalibratedModelParametersBase}"/>.
/// </summary>
public BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>> Fit(IDataView trainData, IDataView validationData)
Expand Down Expand Up @@ -403,18 +403,18 @@ public void AdjustTreeOutputs(IChannel ch, InternalRegressionTree tree,
internal static partial class FastTree
{
[TlcModule.EntryPoint(Name = "Trainers.FastTreeBinaryClassifier",
Desc = FastTreeBinaryClassificationTrainer.Summary,
UserName = FastTreeBinaryClassificationTrainer.UserNameValue,
ShortName = FastTreeBinaryClassificationTrainer.ShortName)]
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, FastTreeBinaryClassificationTrainer.Options input)
Desc = FastTreeBinaryTrainer.Summary,
UserName = FastTreeBinaryTrainer.UserNameValue,
ShortName = FastTreeBinaryTrainer.ShortName)]
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, FastTreeBinaryTrainer.Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainFastTree");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);

return TrainerEntryPointsUtils.Train<FastTreeBinaryClassificationTrainer.Options, CommonOutputs.BinaryClassificationOutput>(host, input,
() => new FastTreeBinaryClassificationTrainer(host, input),
return TrainerEntryPointsUtils.Train<FastTreeBinaryTrainer.Options, CommonOutputs.BinaryClassificationOutput>(host, input,
() => new FastTreeBinaryTrainer(host, input),
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName),
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName),
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.RowGroupColumnName));
Expand Down
Loading