FastTreeRankingTrainer expose non-advanced args(#1246) #1393

Closed · wants to merge 3 commits
10 changes: 9 additions & 1 deletion .vsts-dotnet-ci.yml
@@ -18,7 +18,15 @@ phases:

- template: /build/ci/phase-template.yml
parameters:
name: Windows_NT
name: Windows_x64
buildScript: build.cmd
queue:
name: Hosted VS2017

- template: /build/ci/phase-template.yml
parameters:
name: Windows_x86
architecture: x86
buildScript: build.cmd
queue:
name: Hosted VS2017
2 changes: 1 addition & 1 deletion README.md
@@ -8,7 +8,7 @@ ML.NET allows .NET developers to develop their own models and infuse custom mach

ML.NET was originally developed in Microsoft Research, and evolved into a significant framework over the last decade and is used across many product groups in Microsoft like Windows, Bing, PowerPoint, Excel and more.

With this first preview release, ML.NET enables machine learning tasks like classification (for example: support text classification, sentiment analysis) and regression (for example, price-prediction).
ML.NET enables machine learning tasks like classification (for example: support text classification, sentiment analysis) and regression (for example, price-prediction).

Along with these ML capabilities, this first release of ML.NET also brings the first draft of .NET APIs for training models, using models for predictions, as well as the core components of this framework such as learning algorithms, transforms, and ML data structures.

4 changes: 3 additions & 1 deletion build/ci/phase-template.yml
@@ -1,5 +1,6 @@
parameters:
name: ''
architecture: x64
buildScript: ''
queue: {}

@@ -8,6 +9,7 @@ phases:
variables:
_buildScript: ${{ parameters.buildScript }}
_phaseName: ${{ parameters.name }}
_arch: ${{ parameters.architecture }}
queue:
parallel: 99
matrix:
@@ -17,7 +19,7 @@
_configuration: Release
${{ insert }}: ${{ parameters.queue }}
steps:
- script: $(_buildScript) -$(_configuration)
- script: $(_buildScript) -$(_configuration) -buildArch=$(_arch)
displayName: Build
- ${{ if eq(parameters.name, 'MacOS') }}:
- script: brew install libomp
2 changes: 1 addition & 1 deletion docs/building/unix-instructions.md
@@ -3,7 +3,7 @@ Building ML.NET on Linux and macOS
## Building

1. Install the prerequisites ([Linux](#user-content-linux), [macOS](#user-content-macos))
2. Clone the machine learning repo `git clone https://github.com/dotnet/machinelearning.git`
2. Clone the machine learning repo `git clone --recursive https://github.com/dotnet/machinelearning.git`
3. Navigate to the `machinelearning` directory
4. Run the build script `./build.sh`

11 changes: 10 additions & 1 deletion init-tools.cmd
@@ -12,6 +12,7 @@ set BUILD_TOOLS_PATH=%PACKAGES_DIR%\Microsoft.DotNet.BuildTools\%BUILDTOOLS_VERS
set INIT_TOOLS_RESTORE_PROJECT=%~dp0init-tools.msbuild
set BUILD_TOOLS_SEMAPHORE_DIR=%TOOLRUNTIME_DIR%\%BUILDTOOLS_VERSION%
set BUILD_TOOLS_SEMAPHORE=%BUILD_TOOLS_SEMAPHORE_DIR%\init-tools.completed
set ARCH=x64

:: if force option is specified then clean the tool runtime and build tools package directory to force it to get recreated
if [%1]==[force] (
@@ -47,9 +48,17 @@ echo Running %0 > "%INIT_TOOLS_LOG%"
set /p DOTNET_VERSION=< "%~dp0DotnetCLIVersion.txt"
if exist "%DOTNET_CMD%" goto :afterdotnetrestore

:Arg_Loop
if [%1] == [] goto :ArchSet
if /i [%1] == [x86] ( set ARCH=x86&&goto ArchSet)
shift
goto :Arg_Loop

:ArchSet

echo Installing dotnet cli...
if NOT exist "%DOTNET_PATH%" mkdir "%DOTNET_PATH%"
set DOTNET_ZIP_NAME=dotnet-sdk-%DOTNET_VERSION%-win-x64.zip
set DOTNET_ZIP_NAME=dotnet-sdk-%DOTNET_VERSION%-win-%ARCH%.zip
set DOTNET_REMOTE_PATH=https://dotnetcli.azureedge.net/dotnet/Sdk/%DOTNET_VERSION%/%DOTNET_ZIP_NAME%
set DOTNET_LOCAL_PATH=%DOTNET_PATH%%DOTNET_ZIP_NAME%
echo Installing '%DOTNET_REMOTE_PATH%' to '%DOTNET_LOCAL_PATH%' >> "%INIT_TOOLS_LOG%"
@@ -16,6 +16,6 @@
<ItemGroup>
<Content Include="..\common\CommonPackage.props" Pack="true" PackagePath="build\netstandard2.0\$(MSBuildProjectName).props" />
<Content Include="$(PackageAssetsPath)$(PackageIdFolderName)\LICENSE.txt" Pack="true" PackagePath=".\" />
<Content Include="$(PackageAssetsPath)$(PackageIdFolderName)\THIRD_PARTY_NOTICES.txt" Pack="true" PackagePath=".\" />
<Content Condition="Exists('$(PackageAssetsPath)$(PackageIdFolderName)\THIRD_PARTY_NOTICES.txt')" Include="$(PackageAssetsPath)$(PackageIdFolderName)\THIRD_PARTY_NOTICES.txt" Pack="true" PackagePath=".\" />
</ItemGroup>
</Project>
2 changes: 1 addition & 1 deletion run.cmd
@@ -11,7 +11,7 @@ set DOTNET_SKIP_FIRST_TIME_EXPERIENCE=1
set DOTNET_MULTILEVEL_LOOKUP=0

:: Restore the Tools directory
call "%~dp0init-tools.cmd"
call "%~dp0init-tools.cmd" %*
if NOT [%ERRORLEVEL%]==[0] exit /b 1

set _toolRuntime=%~dp0Tools
37 changes: 35 additions & 2 deletions src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs
@@ -50,7 +50,10 @@ public abstract class TrainerEstimatorBase<TTransformer, TModel> : ITrainerEstim

public abstract PredictionKind PredictionKind { get; }

public TrainerEstimatorBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight = null)
public TrainerEstimatorBase(IHost host,
SchemaShape.Column feature,
SchemaShape.Column label,
SchemaShape.Column weight = null)
{
Contracts.CheckValue(host, nameof(host));
Host = host;
@@ -149,9 +152,39 @@ protected TTransformer TrainTransformer(IDataView trainSet,

protected abstract TTransformer MakeTransformer(TModel model, Schema trainSchema);

private RoleMappedData MakeRoles(IDataView data) =>
protected virtual RoleMappedData MakeRoles(IDataView data) =>
new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, weight: WeightColumn?.Name);

IPredictor ITrainer.Train(TrainContext context) => Train(context);
}

/// <summary>
/// This represents a basic class for 'simple trainer'.
/// A 'simple trainer' accepts one feature column and one label column, also optionally a weight column.
/// It produces a 'prediction transformer'.
/// </summary>
public abstract class TrainerEstimatorBaseWithGroupId<TTransformer, TModel> : TrainerEstimatorBase<TTransformer, TModel>
where TTransformer : ISingleFeaturePredictionTransformer<TModel>
where TModel : IPredictor
{
/// <summary>
/// The optional groupID column that the ranking trainers expects.
/// </summary>
public readonly SchemaShape.Column GroupIdColumn;

public TrainerEstimatorBaseWithGroupId(IHost host,
SchemaShape.Column feature,
SchemaShape.Column label,
SchemaShape.Column weight = null,
SchemaShape.Column groupId = null)
:base(host, feature, label, weight)
{
Host.CheckValueOrNull(groupId);
GroupIdColumn = groupId;
}

protected override RoleMappedData MakeRoles(IDataView data) =>
new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, group: GroupIdColumn?.Name, weight: WeightColumn?.Name);

}
}
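
The change above makes `MakeRoles` `protected virtual` so the new `TrainerEstimatorBaseWithGroupId` can extend the role mapping instead of duplicating it. As a reading aid, here is a minimal self-contained C# sketch of that pattern; `TrainerEstimatorSketch`, `RankingTrainerSketch`, and the column names are stand-ins invented for illustration, not the actual ML.NET classes:

```csharp
using System;
using System.Linq;
using RoleMap = System.Collections.Generic.Dictionary<string, string>;

// Stand-in for the base estimator: builds the feature/label/weight role mapping.
class TrainerEstimatorSketch
{
    protected readonly string FeatureColumn, LabelColumn, WeightColumn;

    public TrainerEstimatorSketch(string feature, string label, string weight = null)
    {
        FeatureColumn = feature;
        LabelColumn = label;
        WeightColumn = weight;
    }

    // Virtual (it was effectively private before), so subclasses can extend the mapping.
    protected virtual RoleMap MakeRoles()
    {
        var roles = new RoleMap { ["feature"] = FeatureColumn, ["label"] = LabelColumn };
        if (WeightColumn != null)
            roles["weight"] = WeightColumn;
        return roles;
    }

    public string DescribeRoles()
        => string.Join(", ", MakeRoles().Select(kv => $"{kv.Key}={kv.Value}"));
}

// Stand-in for the group-id-aware base: same roles plus the optional group id.
class RankingTrainerSketch : TrainerEstimatorSketch
{
    protected readonly string GroupIdColumn;

    public RankingTrainerSketch(string feature, string label, string groupId, string weight = null)
        : base(feature, label, weight)
    {
        GroupIdColumn = groupId;
    }

    protected override RoleMap MakeRoles()
    {
        var roles = base.MakeRoles();
        if (GroupIdColumn != null)
            roles["group"] = GroupIdColumn;
        return roles;
    }
}

class RoleMappingDemo
{
    static void Main()
    {
        var ranker = new RankingTrainerSketch("Features", "Label", "GroupId");
        Console.WriteLine(ranker.DescribeRoles()); // e.g. feature=Features, label=Label, group=GroupId
    }
}
```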
73 changes: 11 additions & 62 deletions src/Microsoft.ML.Data/Training/TrainerUtils.cs
@@ -362,9 +362,14 @@ public static SchemaShape.Column MakeR4ScalarLabel(string labelColumn)
/// <summary>
/// The <see cref="SchemaShape.Column"/> for the label column for regression tasks.
/// </summary>
/// <param name="labelColumn">name of the weight column</param>
public static SchemaShape.Column MakeU4ScalarLabel(string labelColumn)
=> new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
/// <param name="columnName">name of the weight column</param>
public static SchemaShape.Column MakeU4ScalarColumn(string columnName)
{
if (columnName == null)
return null;

return new SchemaShape.Column(columnName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
}

/// <summary>
/// The <see cref="SchemaShape.Column"/> for the feature column.
@@ -377,69 +382,13 @@ public static SchemaShape.Column MakeR4VecFeature(string featureColumn)
/// The <see cref="SchemaShape.Column"/> for the weight column.
/// </summary>
/// <param name="weightColumn">name of the weight column</param>
public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn)
/// <param name="isExplicit">whether the column is implicitly, or explicitly defined</param>
public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, bool isExplicit = true)
{
if (weightColumn == null)
if (weightColumn == null || !isExplicit)
return null;
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}

private static void CheckArgColName(IHostEnvironment host, string defaultColName, string argValue)
{
if (argValue != defaultColName)
throw host.Except($"Don't supply a value for the {defaultColName} column in the arguments, as it will be ignored. Specify them in the loader, or constructor instead instead.");
}

/// <summary>
/// Check that the label, feature, weights, groupId column names are not supplied in the args of the constructor, through the advancedSettings parameter,
/// for cases when the public constructor is called.
/// The recommendation is to set the column names directly.
/// </summary>
public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithGroupId args)
{
// check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly
CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn);
CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn);
CheckArgColName(host, DefaultColumnNames.Weight, args.WeightColumn);

if (args.GroupIdColumn != null)
CheckArgColName(host, DefaultColumnNames.GroupId, args.GroupIdColumn);
}

/// <summary>
/// Check that the label, feature, and weights column names are not supplied in the args of the constructor, through the advancedSettings parameter,
/// for cases when the public constructor is called.
/// The recommendation is to set the column names directly.
/// </summary>
public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithWeight args)
{
// check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly
CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn);
CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn);
CheckArgColName(host, DefaultColumnNames.Weight, args.WeightColumn);
}

/// <summary>
/// Check that the label and feature column names are not supplied in the args of the constructor, through the advancedSettings parameter,
/// for cases when the public constructor is called.
/// The recommendation is to set the column names directly.
/// </summary>
public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithLabel args)
{
// check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly
CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn);
CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn);
}

/// <summary>
/// If, after applying the advancedArgs delegate, the args are different that the default value
/// and are also different than the value supplied directly to the xtension method, warn the user.
/// </summary>
public static void CheckArgsAndAdvancedSettingMismatch<T>(IChannel channel, T methodParam, T defaultVal, T setting, string argName)
{
if (!setting.Equals(defaultVal) && !setting.Equals(methodParam))
channel.Warning($"The value supplied to advanced settings , is different than the value supplied directly. Using value {setting} for {argName}");
}
}

/// <summary>
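
Two utility changes above are worth calling out: `MakeU4ScalarLabel` becomes the more general `MakeU4ScalarColumn`, which now tolerates a `null` column name, and `MakeR4ScalarWeightColumn` gains an `isExplicit` flag so a weight column that was merely defaulted is treated as absent. A compact, self-contained sketch of that null-propagating factory idea (the `Column` record here is a stand-in, not `SchemaShape.Column`):

```csharp
// Stand-in column type; the real helpers return SchemaShape.Column instances.
sealed record Column(string Name, string Kind);

static class ColumnFactorySketch
{
    // Mirrors MakeU4ScalarColumn: a null name means "column not used", so callers can
    // pass optional column names straight through without special-casing them.
    public static Column MakeU4Scalar(string name)
        => name == null ? null : new Column(name, "U4 scalar");

    // Mirrors MakeR4ScalarWeightColumn: a weight column name that was only defaulted
    // (isExplicit == false) is treated the same as having no weight column at all.
    public static Column MakeR4Weight(string name, bool isExplicit = true)
        => name == null || !isExplicit ? null : new Column(name, "R4 scalar");
}
```

The removed `CheckArgsHaveDefaultColNames` and `CheckArgsAndAdvancedSettingMismatch` helpers are no longer needed because, as the FastTree.cs changes further down show, the constructors now overwrite the column-name fields with the directly supplied values after applying `advancedSettings`.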
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj
@@ -2,7 +2,7 @@

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<IncludeInPackage>Microsoft.ML.Ensemble</IncludeInPackage>
<IncludeInPackage>Microsoft.ML</IncludeInPackage>
<DefineConstants>CORECLR</DefineConstants>
</PropertyGroup>

20 changes: 17 additions & 3 deletions src/Microsoft.ML.FastTree/BoostingFastTree.cs
@@ -21,10 +21,24 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaSh
{
}

protected BoostingFastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
string weightColumn = null, string groupIdColumn = null, Action<TArgs> advancedSettings = null)
: base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings)
protected BoostingFastTreeTrainerBase(IHostEnvironment env,
SchemaShape.Column label,
string featureColumn,
string weightColumn,
string groupIdColumn,
int numLeaves,
int numTrees,
int minDocumentsInLeafs,
double learningRate,
Action<TArgs> advancedSettings)
: base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDocumentsInLeafs, advancedSettings)
{

if (Args.LearningRates != learningRate)
{
using (var ch = Host.Start($"Setting learning rate to: {learningRate} as supplied in the direct arguments."))
Args.LearningRates = learningRate;
}
}

protected override void CheckArgs(IChannel ch)
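
In this boosting-specific base class, the commonly tuned `learningRate` is now a direct constructor parameter, and a directly supplied value takes precedence over whatever the `advancedSettings` delegate set, with the override noted through the host channel. A self-contained sketch of that precedence rule (hypothetical `Options` and `BoostingTrainerSketch` types, with `Console` standing in for the host channel):

```csharp
using System;

sealed class Options
{
    // Default learning rate; advancedSettings may change it.
    public double LearningRates = 0.2;
}

sealed class BoostingTrainerSketch
{
    public readonly Options Args = new Options();

    public BoostingTrainerSketch(double learningRate, Action<Options> advancedSettings = null)
    {
        // Apply whatever the caller tweaked through the delegate first...
        advancedSettings?.Invoke(Args);

        // ...then let the directly supplied value win, and note the override.
        if (Args.LearningRates != learningRate)
        {
            Console.WriteLine($"Setting learning rate to: {learningRate} as supplied in the direct arguments.");
            Args.LearningRates = learningRate;
        }
    }
}

class LearningRateDemo
{
    static void Main()
    {
        // The delegate asks for 0.05, but the direct argument 0.1 takes precedence.
        var trainer = new BoostingTrainerSketch(0.1, o => o.LearningRates = 0.05);
        Console.WriteLine(trainer.Args.LearningRates); // 0.1
    }
}
```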
57 changes: 21 additions & 36 deletions src/Microsoft.ML.FastTree/FastTree.cs
@@ -7,6 +7,7 @@
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.Conversion;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Calibration;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Internal.Utilities;
@@ -45,7 +46,7 @@ internal static class FastTreeShared
}

public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
TrainerEstimatorBase<TTransformer, TModel>
TrainerEstimatorBaseWithGroupId<TTransformer, TModel>
where TTransformer: ISingleFeaturePredictionTransformer<TModel>
where TArgs : TreeArgs, new()
where TModel : IPredictorProducing<Float>
@@ -92,26 +93,36 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
/// <summary>
/// Constructor to use when instantiating the classes deriving from here through the API.
/// </summary>
private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
string weightColumn = null, string groupIdColumn = null, Action<TArgs> advancedSettings = null)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
private protected FastTreeTrainerBase(IHostEnvironment env,
SchemaShape.Column label,
string featureColumn,
string weightColumn,
string groupIdColumn,
int numLeaves,
int numTrees,
int minDocumentsInLeafs,
Action<TArgs> advancedSettings)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn))
{
Args = new TArgs();

// set up the directly provided values
// override with the directly provided values.
Args.NumLeaves = numLeaves;
Args.NumTrees = numTrees;
Args.MinDocumentsInLeafs = minDocumentsInLeafs;

//apply the advanced args, if the user supplied any
advancedSettings?.Invoke(Args);

// check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly
TrainerUtils.CheckArgsHaveDefaultColNames(Host, Args);

Args.LabelColumn = label.Name;
Args.FeatureColumn = featureColumn;

if (weightColumn != null)
Args.WeightColumn = weightColumn;
Args.WeightColumn = Optional<string>.Explicit(weightColumn); ;

if (groupIdColumn != null)
Args.GroupIdColumn = groupIdColumn;
Args.GroupIdColumn = Optional<string>.Explicit(groupIdColumn); ;

// The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
// Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
@@ -128,7 +139,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column l
/// Legacy constructor that is used when invoking the classes deriving from this, through maml.
/// </summary>
private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
{
Host.CheckValue(args, nameof(args));
Args = args;
@@ -159,32 +170,6 @@ protected virtual Float GetMaxLabel()
return Float.PositiveInfinity;
}

/// <summary>
/// If, after applying the advancedSettings delegate, the args are different that the default value
/// and are also different than the value supplied directly to the xtension method, warn the user
/// about which value is being used.
/// The parameters that appear here, numTrees, minDocumentsInLeafs, numLeaves, learningRate are the ones the users are most likely to tune.
/// This list should follow the one in the constructor, and the extension methods on the <see cref="TrainContextBase"/>.
/// REVIEW: we should somehow mark the arguments that are set apart in those two places. Currently they stand out by their sort order annotation.
/// </summary>
protected void CheckArgsAndAdvancedSettingMismatch(int numLeaves,
int numTrees,
int minDocumentsInLeafs,
double learningRate,
BoostedTreeArgs snapshot,
BoostedTreeArgs currentArgs)
{
using (var ch = Host.Start("Comparing advanced settings with the directly provided values."))
{

// Check that the user didn't supply different parameters in the args, from what it specified directly.
TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numLeaves, snapshot.NumLeaves, currentArgs.NumLeaves, nameof(numLeaves));
TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numTrees, snapshot.NumTrees, currentArgs.NumTrees, nameof(numTrees));
TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, minDocumentsInLeafs, snapshot.MinDocumentsInLeafs, currentArgs.MinDocumentsInLeafs, nameof(minDocumentsInLeafs));
TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, learningRate, snapshot.LearningRates, currentArgs.LearningRates, nameof(learningRate));
}
}

private void Initialize(IHostEnvironment env)
{
int numThreads = Args.NumThreads ?? Environment.ProcessorCount;
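
The API constructor above records the weight and group-id column names through `Optional<string>.Explicit`, and the legacy maml constructor forwards `args.WeightColumn.IsExplicit` to `MakeR4ScalarWeightColumn`, so a weight column name that is only a framework default is not treated as a real weight column. A hand-rolled sketch of that explicit-versus-implicit distinction (a stand-in for illustration; only `Explicit` and `IsExplicit` appear in the diff, and the `Implicit` factory here is assumed by analogy):

```csharp
// Stand-in for the Optional<T> idea: remembers whether the caller set the value themselves.
sealed class OptionalSketch<T>
{
    public T Value { get; }
    public bool IsExplicit { get; }   // true only when the caller supplied the value directly

    private OptionalSketch(T value, bool isExplicit)
    {
        Value = value;
        IsExplicit = isExplicit;
    }

    // A value supplied directly, e.g. a weight column name passed to the API constructor.
    public static OptionalSketch<T> Explicit(T value) => new OptionalSketch<T>(value, true);

    // A framework-supplied default, e.g. a column name filled in by command-line defaults.
    public static OptionalSketch<T> Implicit(T value) => new OptionalSketch<T>(value, false);
}
```

With that flag available, `MakeR4ScalarWeightColumn(name, isExplicit)` can return `null` for a merely-defaulted weight column, keeping the estimator's schema expectations in line with what the user actually asked for.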