-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Get rid of value tuples in TrainTest and CrossValidation #2507
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,31 @@ public abstract class TrainCatalogBase | |
[BestFriend] | ||
internal IHostEnvironment Environment => Host; | ||
|
||
/// <summary> | ||
/// A pair of datasets, for the train and test set. | ||
/// </summary> | ||
public struct TrainTestData | ||
{ | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I could see us wanting to expand this in the future to also include a validation set, so it could be prudent to keep the name a bit vague. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What would the summary be for PartitionedData? In reply to: 255725077 [](ancestors = 255725077) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
/// <summary> | ||
/// Training set. | ||
/// </summary> | ||
public readonly IDataView TrainSet; | ||
/// <summary> | ||
/// Testing set. | ||
/// </summary> | ||
public readonly IDataView TestSet; | ||
/// <summary> | ||
/// Create pair of datasets. | ||
/// </summary> | ||
/// <param name="trainSet">Training set.</param> | ||
/// <param name="testSet">Testing set.</param> | ||
internal TrainTestData(IDataView trainSet, IDataView testSet) | ||
{ | ||
TrainSet = trainSet; | ||
TestSet = testSet; | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Split the dataset into the train set and test set according to the given fraction. | ||
/// Respects the <paramref name="stratificationColumn"/> if provided. | ||
|
@@ -37,8 +62,7 @@ public abstract class TrainCatalogBase | |
/// <param name="seed">Optional parameter used in combination with the <paramref name="stratificationColumn"/>. | ||
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value. | ||
/// And if it is not provided, the default value will be used.</param> | ||
/// <returns>A pair of datasets, for the train and test set.</returns> | ||
public (IDataView trainSet, IDataView testSet) TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null, uint? seed = null) | ||
public TrainTestData TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null, uint? seed = null) | ||
{ | ||
Host.CheckValue(data, nameof(data)); | ||
Host.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive"); | ||
|
@@ -61,14 +85,71 @@ public abstract class TrainCatalogBase | |
Complement = false | ||
}, data); | ||
|
||
return (trainFilter, testFilter); | ||
return new TrainTestData(trainFilter, testFilter); | ||
} | ||
|
||
/// <summary> | ||
/// Results for a specific cross-validation fold: the trained model, the scored test set and the fold number. | ||
/// </summary> | ||
protected internal struct CrossValidationResult | ||
{ | ||
/// <summary> | ||
/// Model trained during this cross-validation fold. | ||
/// </summary> | ||
public readonly ITransformer Model; | ||
/// <summary> | ||
/// Test set of this fold, scored with <see cref="Model"/>. | ||
/// </summary> | ||
public readonly IDataView Scores; | ||
/// <summary> | ||
/// Fold number. | ||
/// </summary> | ||
public readonly int Fold; | ||
|
||
/// <summary> | ||
/// Create the result for a single cross-validation fold. | ||
/// </summary> | ||
/// <param name="model">Model trained during the fold; stored in <see cref="Model"/>.</param> | ||
/// <param name="scores">Scored test set for the fold; stored in <see cref="Scores"/>.</param> | ||
/// <param name="fold">Fold number; stored in <see cref="Fold"/>.</param> | ||
public CrossValidationResult(ITransformer model, IDataView scores, int fold) | ||
{ | ||
Model = model; | ||
Scores = scores; | ||
Fold = fold; | ||
} | ||
} | ||
/// <summary> | ||
/// Results of running cross-validation. | ||
/// </summary> | ||
/// <typeparam name="T">Type of metric class.</typeparam> | ||
public sealed class CrossValidationResult<T> where T : class | ||
{ | ||
/// <summary> | ||
/// Metrics for this cross-validation fold. | ||
/// </summary> | ||
public readonly T Metrics; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Part of me wants to create an interface and mark all Metric classes with it. Another part of me remembers that we are trying to get rid of all empty interfaces. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
/// <summary> | ||
/// Model trained during cross-validation fold. | ||
/// </summary> | ||
public readonly ITransformer Model; | ||
/// <summary> | ||
/// The scored hold-out set for this fold. | ||
/// </summary> | ||
public readonly IDataView ScoredHoldOutSet; | ||
/// <summary> | ||
/// Fold number. | ||
/// </summary> | ||
public readonly int Fold; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Is this necessary, since they will be returned in an (ordered) array? And it'll be confusing if the order of the array doesn't match the fold number. #Resolved There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is it necessary? No. Is it nice to have? I think so. As soon as you pull the results out of the array and, say, want to report specific folds and do some work on them, you need to pass the fold number around anyway.
In reply to: 255727531 [](ancestors = 255727531) |
||
|
||
internal CrossValidationResult(ITransformer model, T metrics, IDataView scores, int fold) | ||
{ | ||
Model = model; | ||
Metrics = metrics; | ||
ScoredHoldOutSet = scores; | ||
Fold = fold; | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Train the <paramref name="estimator"/> on <paramref name="numFolds"/> folds of the data sequentially. | ||
/// Return each model and each scored test dataset. | ||
/// </summary> | ||
protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidateTrain(IDataView data, IEstimator<ITransformer> estimator, | ||
protected internal CrossValidationResult[] CrossValidateTrain(IDataView data, IEstimator<ITransformer> estimator, | ||
int numFolds, string stratificationColumn, uint? seed = null) | ||
{ | ||
Host.CheckValue(data, nameof(data)); | ||
|
@@ -78,7 +159,7 @@ protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidate | |
|
||
EnsureStratificationColumn(ref data, ref stratificationColumn, seed); | ||
|
||
Func<int, (IDataView scores, ITransformer model)> foldFunction = | ||
Func<int, CrossValidationResult> foldFunction = | ||
fold => | ||
{ | ||
var trainFilter = new RangeFilter(Host, new RangeFilter.Options | ||
|
@@ -98,17 +179,17 @@ protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidate | |
|
||
var model = estimator.Fit(trainFilter); | ||
var scoredTest = model.Transform(testFilter); | ||
return (scoredTest, model); | ||
return new CrossValidationResult(model, scoredTest, fold); | ||
}; | ||
|
||
// Sequential per-fold training. | ||
// REVIEW: we could have a parallel implementation here. We would need to | ||
// spawn off a separate host per fold in that case. | ||
var result = new List<(IDataView scores, ITransformer model)>(); | ||
var result = new CrossValidationResult[numFolds]; | ||
for (int fold = 0; fold < numFolds; fold++) | ||
result.Add(foldFunction(fold)); | ||
result[fold] = foldFunction(fold); | ||
|
||
return result.ToArray(); | ||
return result; | ||
} | ||
|
||
protected internal TrainCatalogBase(IHostEnvironment env, string registrationName) | ||
|
@@ -263,13 +344,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string | |
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value. | ||
/// And if it is not provided, the default value will be used.</param> | ||
/// <returns>Per-fold results: metrics, models, scored datasets.</returns> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
maybe <see cref="CrossValidationResult"/> |
||
public (BinaryClassificationMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidateNonCalibrated( | ||
public CrossValidationResult<BinaryClassificationMetrics>[] CrossValidateNonCalibrated( | ||
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, | ||
string stratificationColumn = null, uint? seed = null) | ||
{ | ||
Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); | ||
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); | ||
return result.Select(x => (EvaluateNonCalibrated(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); | ||
return result.Select(x => new CrossValidationResult<BinaryClassificationMetrics>(x.Model, | ||
EvaluateNonCalibrated(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray(); | ||
} | ||
|
||
/// <summary> | ||
|
@@ -287,13 +369,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string | |
/// train to the test set.</remarks> | ||
/// <param name="seed">If <paramref name="stratificationColumn"/> not present in dataset we will generate random filled column based on provided <paramref name="seed"/>.</param> | ||
/// <returns>Per-fold results: metrics, models, scored datasets.</returns> | ||
public (CalibratedBinaryClassificationMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( | ||
public CrossValidationResult<CalibratedBinaryClassificationMetrics>[] CrossValidate( | ||
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, | ||
string stratificationColumn = null, uint? seed = null) | ||
{ | ||
Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); | ||
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); | ||
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); | ||
return result.Select(x => new CrossValidationResult<CalibratedBinaryClassificationMetrics>(x.Model, | ||
Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray(); | ||
} | ||
} | ||
|
||
|
@@ -369,12 +452,13 @@ public ClusteringMetrics Evaluate(IDataView data, | |
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value. | ||
/// And if it is not provided, the default value will be used.</param> | ||
/// <returns>Per-fold results: metrics, models, scored datasets.</returns> | ||
public (ClusteringMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( | ||
public CrossValidationResult<ClusteringMetrics>[] CrossValidate( | ||
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = null, string featuresColumn = null, | ||
string stratificationColumn = null, uint? seed = null) | ||
{ | ||
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); | ||
return result.Select(x => (Evaluate(x.scoredTestSet, label: labelColumn, features: featuresColumn), x.model, x.scoredTestSet)).ToArray(); | ||
return result.Select(x => new CrossValidationResult<ClusteringMetrics>(x.Model, | ||
Evaluate(x.Scores, label: labelColumn, features: featuresColumn), x.Scores, x.Fold)).ToArray(); | ||
} | ||
} | ||
|
||
|
@@ -444,13 +528,14 @@ public MultiClassClassifierMetrics Evaluate(IDataView data, string label = Defau | |
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value. | ||
/// And if it is not provided, the default value will be used.</param> | ||
/// <returns>Per-fold results: metrics, models, scored datasets.</returns> | ||
public (MultiClassClassifierMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( | ||
public CrossValidationResult<MultiClassClassifierMetrics>[] CrossValidate( | ||
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, | ||
string stratificationColumn = null, uint? seed = null) | ||
{ | ||
Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); | ||
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); | ||
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); | ||
return result.Select(x => new CrossValidationResult<MultiClassClassifierMetrics>(x.Model, | ||
Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray(); | ||
} | ||
} | ||
|
||
|
@@ -511,13 +596,14 @@ public RegressionMetrics Evaluate(IDataView data, string label = DefaultColumnNa | |
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value. | ||
/// And if it is not provided, the default value will be used.</param> | ||
/// <returns>Per-fold results: metrics, models, scored datasets.</returns> | ||
public (RegressionMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( | ||
public CrossValidationResult<RegressionMetrics>[] CrossValidate( | ||
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, | ||
string stratificationColumn = null, uint? seed = null) | ||
{ | ||
Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); | ||
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); | ||
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); | ||
return result.Select(x => new CrossValidationResult<RegressionMetrics>(x.Model, | ||
Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray(); | ||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why struct? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why a class?
It's a small collection of immutable objects.
In reply to: 255723990 [](ancestors = 255723990)