-
Notifications
You must be signed in to change notification settings - Fork 1.9k
TrainTestSplit function #1005
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TrainTestSplit function #1005
Changes from all commits
7f04b87
260aff6
ab76c3d
6c75e84
227f38a
adaa9da
985fe7d
106a6a0
fbee542
ca10eb5
7569e08
f586984
e61ca70
e56ce30
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,9 +2,15 @@ | |
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Core.Data; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Transforms; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
|
||
namespace Microsoft.ML.Runtime.Training | ||
namespace Microsoft.ML | ||
{ | ||
/// <summary> | ||
/// A training context is an object instantiable by a user to do various tasks relating to a particular | ||
|
@@ -16,13 +22,136 @@ public abstract class TrainContextBase | |
protected readonly IHost Host; | ||
internal IHostEnvironment Environment => Host; | ||
|
||
/// <summary> | ||
/// Split the dataset into the train set and test set according to the given fraction. | ||
/// Respects the <paramref name="stratificationColumn"/> if provided. | ||
/// </summary> | ||
/// <param name="data">The dataset to split.</param> | ||
/// <param name="testFraction">The fraction of data to go into the test set.</param> | ||
/// <param name="stratificationColumn">Optional stratification column.</param> | ||
/// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided), | ||
/// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from | ||
/// train to the test set.</remarks> | ||
/// <returns>A pair of datasets, for the train and test set.</returns> | ||
public (IDataView trainSet, IDataView testSet) TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null) | ||
{ | ||
Host.CheckValue(data, nameof(data)); | ||
Host.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive"); | ||
Host.CheckValueOrNull(stratificationColumn); | ||
|
||
EnsureStratificationColumn(ref data, ref stratificationColumn); | ||
|
||
var trainFilter = new RangeFilter(Host, new RangeFilter.Arguments() | ||
{ | ||
Column = stratificationColumn, | ||
Min = 0, | ||
Max = testFraction, | ||
Complement = true | ||
}, data); | ||
var testFilter = new RangeFilter(Host, new RangeFilter.Arguments() | ||
{ | ||
Column = stratificationColumn, | ||
Min = 0, | ||
Max = testFraction, | ||
Complement = false | ||
}, data); | ||
|
||
return (trainFilter, testFilter); | ||
} | ||
|
||
/// <summary> | ||
/// Train the <paramref name="estimator"/> on <paramref name="numFolds"/> folds of the data sequentially. | ||
/// Return each model and each scored test dataset. | ||
/// </summary> | ||
protected (IDataView scoredTestSet, ITransformer model)[] CrossValidateTrain(IDataView data, IEstimator<ITransformer> estimator, | ||
int numFolds, string stratificationColumn) | ||
{ | ||
Host.CheckValue(data, nameof(data)); | ||
Host.CheckValue(estimator, nameof(estimator)); | ||
Host.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1"); | ||
Host.CheckValueOrNull(stratificationColumn); | ||
|
||
EnsureStratificationColumn(ref data, ref stratificationColumn); | ||
|
||
Func<int, (IDataView scores, ITransformer model)> foldFunction = | ||
fold => | ||
{ | ||
var trainFilter = new RangeFilter(Host, new RangeFilter.Arguments | ||
{ | ||
Column = stratificationColumn, | ||
Min = (double)fold / numFolds, | ||
Max = (double)(fold + 1) / numFolds, | ||
Complement = true | ||
}, data); | ||
var testFilter = new RangeFilter(Host, new RangeFilter.Arguments | ||
{ | ||
Column = stratificationColumn, | ||
Min = (double)fold / numFolds, | ||
Max = (double)(fold + 1) / numFolds, | ||
Complement = false | ||
}, data); | ||
|
||
var model = estimator.Fit(trainFilter); | ||
var scoredTest = model.Transform(testFilter); | ||
return (scoredTest, model); | ||
}; | ||
|
||
// Sequential per-fold training. | ||
// REVIEW: we could have a parallel implementation here. We would need to | ||
// spawn off a separate host per fold in that case. | ||
var result = new List<(IDataView scores, ITransformer model)>(); | ||
for (int fold = 0; fold < numFolds; fold++) | ||
result.Add(foldFunction(fold)); | ||
|
||
return result.ToArray(); | ||
} | ||
|
||
protected TrainContextBase(IHostEnvironment env, string registrationName) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
env.CheckNonEmpty(registrationName, nameof(registrationName)); | ||
Host = env.Register(registrationName); | ||
} | ||
|
||
/// <summary> | ||
/// Make sure the provided <paramref name="stratificationColumn"/> is valid | ||
/// for <see cref="RangeFilter"/>, hash it if needed, or introduce a new one | ||
/// if needed. | ||
/// </summary> | ||
private void EnsureStratificationColumn(ref IDataView data, ref string stratificationColumn) | ||
{ | ||
// We need to handle two cases: if the stratification column is provided, we use hashJoin to | ||
// build a single hash of it. If it is not, we generate a random number. | ||
|
||
if (stratificationColumn == null) | ||
{ | ||
stratificationColumn = data.Schema.GetTempColumnName("StratificationColumn"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
is there a DefaultColumnName for this? #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I looked, there is none. I don't think there should be either, it's a pretty peculiar concept, unique to CV and split scenarios In reply to: 220372046 [](ancestors = 220372046) |
||
data = new GenerateNumberTransform(Host, data, stratificationColumn); | ||
} | ||
else | ||
{ | ||
if (!data.Schema.TryGetColumnIndex(stratificationColumn, out int stratCol)) | ||
throw Host.ExceptSchemaMismatch(nameof(stratificationColumn), "stratification", stratificationColumn); | ||
|
||
var type = data.Schema.GetColumnType(stratCol); | ||
if (!RangeFilter.IsValidRangeFilterColumnType(Host, type)) | ||
{ | ||
// Hash the stratification column. | ||
// REVIEW: this could currently crash, since Hash only accepts a limited set | ||
// of column types. It used to be HashJoin, but we should probably extend Hash | ||
// instead of having two hash transformations. | ||
var origStratCol = stratificationColumn; | ||
int tmp; | ||
int inc = 0; | ||
|
||
// Generate a new column with the hashed stratification column. | ||
while (data.Schema.TryGetColumnIndex(stratificationColumn, out tmp)) | ||
stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc); | ||
data = new HashEstimator(Host, origStratCol, stratificationColumn, 30).Fit(data).Transform(data); | ||
} | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Subclasses of <see cref="TrainContext"/> will provide little "extension method" hookable objects | ||
/// (e.g., something like <see cref="BinaryClassificationContext.Trainers"/>). User code will only | ||
|
@@ -140,6 +269,50 @@ public BinaryClassifierEvaluator.Result EvaluateNonCalibrated(IDataView data, st | |
var eval = new BinaryClassifierEvaluator(Host, new BinaryClassifierEvaluator.Arguments() { }); | ||
return eval.Evaluate(data, label, score, predictedLabel); | ||
} | ||
|
||
/// <summary> | ||
/// Run cross-validation over <paramref name="numFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>, | ||
/// and respecting <paramref name="stratificationColumn"/> if provided. | ||
/// Then evaluate each sub-model against <paramref name="labelColumn"/> and return metrics. | ||
/// </summary> | ||
/// <param name="data">The data to run cross-validation on.</param> | ||
/// <param name="estimator">The estimator to fit.</param> | ||
/// <param name="numFolds">Number of cross-validation folds.</param> | ||
/// <param name="labelColumn">The label column (for evaluation).</param> | ||
/// <param name="stratificationColumn">Optional stratification column.</param> | ||
/// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided), | ||
/// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from | ||
/// train to the test set.</remarks> | ||
/// <returns>Per-fold results: metrics, models, scored datasets.</returns> | ||
public (BinaryClassifierEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidateNonCalibrated( | ||
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null) | ||
{ | ||
Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); | ||
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn); | ||
return result.Select(x => (EvaluateNonCalibrated(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); | ||
} | ||
|
||
/// <summary> | ||
/// Run cross-validation over <paramref name="numFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>, | ||
/// and respecting <paramref name="stratificationColumn"/> if provided. | ||
/// Then evaluate each sub-model against <paramref name="labelColumn"/> and return metrics. | ||
/// </summary> | ||
/// <param name="data">The data to run cross-validation on.</param> | ||
/// <param name="estimator">The estimator to fit.</param> | ||
/// <param name="numFolds">Number of cross-validation folds.</param> | ||
/// <param name="labelColumn">The label column (for evaluation).</param> | ||
/// <param name="stratificationColumn">Optional stratification column.</param> | ||
/// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided), | ||
/// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from | ||
/// train to the test set.</remarks> | ||
/// <returns>Per-fold results: metrics, models, scored datasets.</returns> | ||
public (BinaryClassifierEvaluator.CalibratedResult metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( | ||
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null) | ||
{ | ||
Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); | ||
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn); | ||
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); | ||
} | ||
} | ||
|
||
/// <summary> | ||
|
@@ -191,6 +364,28 @@ public MultiClassClassifierEvaluator.Result Evaluate(IDataView data, string labe | |
var eval = new MultiClassClassifierEvaluator(Host, args); | ||
return eval.Evaluate(data, label, score, predictedLabel); | ||
} | ||
|
||
/// <summary> | ||
/// Run cross-validation over <paramref name="numFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>, | ||
/// and respecting <paramref name="stratificationColumn"/> if provided. | ||
/// Then evaluate each sub-model against <paramref name="labelColumn"/> and return metrics. | ||
/// </summary> | ||
/// <param name="data">The data to run cross-validation on.</param> | ||
/// <param name="estimator">The estimator to fit.</param> | ||
/// <param name="numFolds">Number of cross-validation folds.</param> | ||
/// <param name="labelColumn">The label column (for evaluation).</param> | ||
/// <param name="stratificationColumn">Optional stratification column.</param> | ||
/// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided), | ||
/// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from | ||
/// train to the test set.</remarks> | ||
/// <returns>Per-fold results: metrics, models, scored datasets.</returns> | ||
public (MultiClassClassifierEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( | ||
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null) | ||
{ | ||
Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); | ||
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn); | ||
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); | ||
} | ||
} | ||
|
||
/// <summary> | ||
|
@@ -233,5 +428,27 @@ public RegressionEvaluator.Result Evaluate(IDataView data, string label, string | |
var eval = new RegressionEvaluator(Host, new RegressionEvaluator.Arguments() { }); | ||
return eval.Evaluate(data, label, score); | ||
} | ||
|
||
/// <summary> | ||
/// Run cross-validation over <paramref name="numFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>, | ||
/// and respecting <paramref name="stratificationColumn"/> if provided. | ||
/// Then evaluate each sub-model against <paramref name="labelColumn"/> and return metrics. | ||
/// </summary> | ||
/// <param name="data">The data to run cross-validation on.</param> | ||
/// <param name="estimator">The estimator to fit.</param> | ||
/// <param name="numFolds">Number of cross-validation folds.</param> | ||
/// <param name="labelColumn">The label column (for evaluation).</param> | ||
/// <param name="stratificationColumn">Optional stratification column.</param> | ||
/// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided), | ||
/// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from | ||
/// train to the test set.</remarks> | ||
/// <returns>Per-fold results: metrics, models, scored datasets.</returns> | ||
public (RegressionEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( | ||
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null) | ||
{ | ||
Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); | ||
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn); | ||
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
curious, why not return a list? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We tend to avoid returning mutable collections.
In reply to: 220371873 [](ancestors = 220371873)