Skip to content

Modify API for advanced settings. (FastTree, RandomForest) #2047

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 11, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/Microsoft.ML.Data/EntryPoints/InputBase.cs
Original file line number Diff line number Diff line change
@@ -35,15 +35,27 @@ public enum CachingOptions
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInput))]
public abstract class LearnerInputBase
{
/// <summary>
/// The data to be used for training.
/// </summary>
/// <remarks>
/// Declared as a required argument (<c>ArgumentType.Required</c>) and has no default value.
/// Marked with <c>VisibilityType.EntryPointsOnly</c>.
/// </remarks>
[Argument(ArgumentType.Required, ShortName = "data", HelpText = "The data to be used for training", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public IDataView TrainingData;

/// <summary>
/// Column to use for features.
/// </summary>
/// <remarks>
/// Defaults to <c>DefaultColumnNames.Features</c>. May be specified at most once (short name "feat").
/// Marked with <c>VisibilityType.EntryPointsOnly</c>.
/// </remarks>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for features", ShortName = "feat", SortOrder = 2, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string FeatureColumn = DefaultColumnNames.Features;

/// <summary>
/// Normalize option for the feature column.
/// </summary>
/// <remarks>
/// Defaults to <c>NormalizeOption.Auto</c>. May be specified at most once (short name "norm").
/// Marked with <c>VisibilityType.EntryPointsOnly</c>.
/// </remarks>
[Argument(ArgumentType.AtMostOnce, HelpText = "Normalize option for the feature column", ShortName = "norm", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public NormalizeOption NormalizeFeatures = NormalizeOption.Auto;

/// <summary>
/// Whether learner should cache input training data.
/// </summary>
/// <remarks>
/// Defaults to <c>CachingOptions.Auto</c>. Unlike the other arguments of this class, it is declared with
/// <c>ArgumentType.LastOccurenceWins</c> rather than <c>ArgumentType.AtMostOnce</c> (short name "cache").
/// Marked with <c>VisibilityType.EntryPointsOnly</c>.
/// </remarks>
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether learner should cache input training data", ShortName = "cache", SortOrder = 6, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public CachingOptions Caching = CachingOptions.Auto;
}
@@ -54,6 +66,9 @@ public abstract class LearnerInputBase
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithLabel))]
public abstract class LearnerInputBaseWithLabel : LearnerInputBase
{
/// <summary>
/// Column to use for labels.
/// </summary>
/// <remarks>
/// Defaults to <c>DefaultColumnNames.Label</c>. May be specified at most once (short name "lab").
/// Marked with <c>VisibilityType.EntryPointsOnly</c>.
/// </remarks>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 3, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string LabelColumn = DefaultColumnNames.Label;
}
@@ -65,6 +80,9 @@ public abstract class LearnerInputBaseWithLabel : LearnerInputBase
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithWeight))]
public abstract class LearnerInputBaseWithWeight : LearnerInputBaseWithLabel
{
/// <summary>
/// Column to use for example weight.
/// </summary>
/// <remarks>
/// Initialized through <c>Optional&lt;string&gt;.Implicit</c> with <c>DefaultColumnNames.Weight</c>.
/// May be specified at most once (short name "weight"). Marked with <c>VisibilityType.EntryPointsOnly</c>.
/// NOTE(review): presumably the implicit wrapper lets consumers tell a defaulted value apart from one
/// the caller set explicitly - confirm against the definition of <c>Optional&lt;T&gt;</c>.
/// </remarks>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public Optional<string> WeightColumn = Optional<string>.Implicit(DefaultColumnNames.Weight);
}
@@ -95,6 +113,9 @@ public abstract class EvaluateInputBase
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithGroupId))]
public abstract class LearnerInputBaseWithGroupId : LearnerInputBaseWithWeight
{
/// <summary>
/// Column to use for example groupId.
/// </summary>
/// <remarks>
/// Initialized through <c>Optional&lt;string&gt;.Implicit</c> with <c>DefaultColumnNames.GroupId</c>.
/// May be specified at most once (short name "groupId"). Marked with <c>VisibilityType.EntryPointsOnly</c>.
/// NOTE(review): this argument uses SortOrder = 5, the same value as the inherited NormalizeFeatures
/// argument - confirm the tie is intentional.
/// </remarks>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example groupId", ShortName = "groupId", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public Optional<string> GroupIdColumn = Optional<string>.Implicit(DefaultColumnNames.GroupId);
}
12 changes: 3 additions & 9 deletions src/Microsoft.ML.FastTree/BoostingFastTree.cs
Original file line number Diff line number Diff line change
@@ -28,16 +28,10 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env,
int numLeaves,
int numTrees,
int minDatapointsInLeaves,
double learningRate,
Action<TArgs> advancedSettings)
: base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves, advancedSettings)
double learningRate)
: base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves)
{

if (Args.LearningRates != learningRate)
{
using (var ch = Host.Start($"Setting learning rate to: {learningRate} as supplied in the direct arguments."))
Args.LearningRates = learningRate;
}
Args.LearningRates = learningRate;
}

protected override void CheckArgs(IChannel ch)
8 changes: 2 additions & 6 deletions src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
@@ -114,8 +114,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env,
string groupIdColumn,
int numLeaves,
int numTrees,
int minDatapointsInLeaves,
Action<TArgs> advancedSettings)
int minDatapointsInLeaves)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn))
{
Args = new TArgs();
@@ -126,9 +125,6 @@ private protected FastTreeTrainerBase(IHostEnvironment env,
Args.NumTrees = numTrees;
Args.MinDocumentsInLeafs = minDatapointsInLeaves;

//apply the advanced args, if the user supplied any
advancedSettings?.Invoke(Args);

Args.LabelColumn = label.Name;
Args.FeatureColumn = featureColumn;

@@ -152,7 +148,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env,
}

/// <summary>
/// Legacy constructor that is used when invoking the classes deriving from this, through maml.
/// Constructor that is used when invoking the classes deriving from this, through maml.
/// </summary>
private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
Loading