Skip to content

Get rid of value tuples in TrainTest and CrossValidation #2507

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/code/MlNetCookBook.md
Original file line number Diff line number Diff line change
Expand Up @@ -825,20 +825,20 @@ var pipeline =
.Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent());

// Split the data 90:10 into train and test sets, train and evaluate.
var (trainData, testData) = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1);
var split = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1);

// Train the model.
var model = pipeline.Fit(trainData);
var model = pipeline.Fit(split.TrainSet);
// Compute quality metrics on the test set.
var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(testData));
var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(split.TestSet));
Console.WriteLine(metrics.AccuracyMicro);

// Now run the 5-fold cross-validation experiment, using the same pipeline.
var cvResults = mlContext.MulticlassClassification.CrossValidate(data, pipeline, numFolds: 5);

// The results object is an array of 5 elements. For each of the 5 folds, we have metrics, model and scored test data.
// Let's compute the average micro-accuracy.
var microAccuracies = cvResults.Select(r => r.metrics.AccuracyMicro);
var microAccuracies = cvResults.Select(r => r.Metrics.AccuracyMicro);
Console.WriteLine(microAccuracies.Average());

```
Expand Down
6 changes: 3 additions & 3 deletions docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public static void Calibration()
var data = reader.Read(dataFile);

// Split the dataset into two parts: one used for training, the other to train the calibrator
var (trainData, calibratorTrainingData) = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1);
var split = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1);

// Featurize the text column through the FeaturizeText API.
// Then append the StochasticDualCoordinateAscentBinary binary classifier, setting the "Label" column as the label of the dataset, and
Expand All @@ -56,12 +56,12 @@ public static void Calibration()
loss: new HingeLoss())); // By specifying loss: new HingeLoss(), StochasticDualCoordinateAscent will train a support vector machine (SVM).

// Fit the pipeline, and get a transformer that knows how to score new data.
var transformer = pipeline.Fit(trainData);
var transformer = pipeline.Fit(split.TrainSet);
IPredictor model = transformer.LastTransformer.Model;

// Let's score the new data. The score will give us a numerical estimation of the chance that the particular sample
// bears positive sentiment. This estimate is relative to the numbers obtained.
var scoredData = transformer.Transform(calibratorTrainingData);
var scoredData = transformer.Transform(split.TestSet);
var scoredDataPreview = scoredData.Preview();

PrintRowViewValues(scoredDataPreview);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public static void LogisticRegression()

IDataView data = reader.Read(dataFilePath);

var (trainData, testData) = ml.BinaryClassification.TrainTestSplit(data, testFraction: 0.2);
var split = ml.BinaryClassification.TrainTestSplit(data, testFraction: 0.2);

var pipeline = ml.Transforms.Concatenate("Text", "workclass", "education", "marital-status",
"relationship", "ethnicity", "sex", "native-country")
Expand All @@ -66,9 +66,9 @@ public static void LogisticRegression()
"education-num", "capital-gain", "capital-loss", "hours-per-week"))
.Append(ml.BinaryClassification.Trainers.LogisticRegression());

var model = pipeline.Fit(trainData);
var model = pipeline.Fit(split.TrainSet);

var dataWithPredictions = model.Transform(testData);
var dataWithPredictions = model.Transform(split.TestSet);

var metrics = ml.BinaryClassification.Evaluate(dataWithPredictions);

Expand Down
4 changes: 2 additions & 2 deletions docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public static void SDCA_BinaryClassification()
// Step 3: Run Cross-Validation on this pipeline.
var cvResults = mlContext.BinaryClassification.CrossValidate(data, pipeline, labelColumn: "Sentiment");

var accuracies = cvResults.Select(r => r.metrics.Accuracy);
var accuracies = cvResults.Select(r => r.Metrics.Accuracy);
Console.WriteLine(accuracies.Average());

// If we wanted to specify more advanced parameters for the algorithm,
Expand All @@ -70,7 +70,7 @@ public static void SDCA_BinaryClassification()

// Run Cross-Validation on this second pipeline.
var cvResults_advancedPipeline = mlContext.BinaryClassification.CrossValidate(data, pipeline, labelColumn: "Sentiment", numFolds: 3);
accuracies = cvResults_advancedPipeline.Select(r => r.metrics.Accuracy);
accuracies = cvResults_advancedPipeline.Select(r => r.Metrics.Accuracy);
Console.WriteLine(accuracies.Average());

}
Expand Down
124 changes: 105 additions & 19 deletions src/Microsoft.ML.Data/TrainCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,31 @@ public abstract class TrainCatalogBase
[BestFriend]
internal IHostEnvironment Environment => Host;

/// <summary>
/// A pair of datasets, for the train and test set.
/// </summary>
public struct TrainTestData
{
Copy link
Contributor

@rogancarr rogancarr Feb 11, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why struct? #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why class?
It's a small collection of immutable objects.


In reply to: 255723990 [](ancestors = 255723990)

Copy link
Contributor

@rogancarr rogancarr Feb 11, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could see us wanting to expand this in the future to also include a validation set, so it could be prudent to keep the name a bit vague. PartitionedData? #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would the summary for PartitionedData be?
This one is used as the output of the TrainTestSplit function; if we ever introduce TrainTestValidationSplit, I would prefer to create another object for that.


In reply to: 255725077 [](ancestors = 255725077)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point!


In reply to: 255725959 [](ancestors = 255725959,255725077)

/// <summary>
/// Training set.
/// </summary>
public readonly IDataView TrainSet;
/// <summary>
/// Testing set.
/// </summary>
public readonly IDataView TestSet;
/// <summary>
/// Create pair of datasets.
/// </summary>
/// <param name="trainSet">Training set.</param>
/// <param name="testSet">Testing set.</param>
internal TrainTestData(IDataView trainSet, IDataView testSet)
{
TrainSet = trainSet;
TestSet = testSet;
}
}

/// <summary>
/// Split the dataset into the train set and test set according to the given fraction.
/// Respects the <paramref name="stratificationColumn"/> if provided.
Expand All @@ -37,8 +62,7 @@ public abstract class TrainCatalogBase
/// <param name="seed">Optional parameter used in combination with the <paramref name="stratificationColumn"/>.
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it will use this seed.
/// If the seed is not provided either, the default value will be used.</param>
/// <returns>A pair of datasets, for the train and test set.</returns>
public (IDataView trainSet, IDataView testSet) TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null, uint? seed = null)
public TrainTestData TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null, uint? seed = null)
{
Host.CheckValue(data, nameof(data));
Host.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive");
Expand All @@ -61,14 +85,71 @@ public abstract class TrainCatalogBase
Complement = false
}, data);

return (trainFilter, testFilter);
return new TrainTestData(trainFilter, testFilter);
}

/// <summary>
/// Results for a specific cross-validation fold: the trained model, the scored test set
/// and the fold number. This type carries no metrics.
/// </summary>
protected internal struct CrossValidationResult
{
/// <summary>
/// Model trained during cross validation fold.
/// </summary>
public readonly ITransformer Model;
/// <summary>
/// Scored test set with <see cref="Model"/> for this fold.
/// </summary>
public readonly IDataView Scores;
/// <summary>
/// Fold number.
/// </summary>
public readonly int Fold;

/// <summary>
/// Create the result of a single cross-validation fold.
/// </summary>
/// <param name="model">Model trained during the fold; stored in <see cref="Model"/>.</param>
/// <param name="scores">Scored test set for the fold; stored in <see cref="Scores"/>.</param>
/// <param name="fold">Fold number; stored in <see cref="Fold"/>.</param>
public CrossValidationResult(ITransformer model, IDataView scores, int fold)
{
Model = model;
Scores = scores;
Fold = fold;
}
}
/// <summary>
/// Results of running cross-validation.
/// </summary>
/// <typeparam name="T">Type of metric class.</typeparam>
public sealed class CrossValidationResult<T> where T : class
{
/// <summary>
/// Metrics for this cross-validation fold.
/// </summary>
public readonly T Metrics;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

T [](start = 28, length = 1)

Part of me wants to create an interface and mark all Metric classes with it. Another part of me remembers that we are trying to get rid of all empty interfaces.
Maybe an empty base class?
I would like to hear your comments as well.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 for some way of making this not generic.


In reply to: 255714622 [](ancestors = 255714622)

/// <summary>
/// Model trained during cross-validation fold.
/// </summary>
public readonly ITransformer Model;
/// <summary>
/// The scored hold-out set for this fold.
/// </summary>
public readonly IDataView ScoredHoldOutSet;
/// <summary>
/// Fold number.
/// </summary>
public readonly int Fold;
Copy link
Contributor

@rogancarr rogancarr Feb 11, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

public readonly int Fold [](start = 12, length = 24)

Is this necessary, since they will be returned in an (ordered) array? And it'll be confusing if the order of the array doesn't match the fold number. #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary? No. Is it nice to have? I think so. As soon as you tear them away from the array and, let's say, want to report specific folds and do some stuff on them, you need to pass the fold around anyway.

> it'll be confusing if the order of the array doesn't match the fold number.
That's not the case right now, but I would prefer to know the real fold rather than the index in the array.


In reply to: 255727531 [](ancestors = 255727531)


/// <summary>
/// Create the result of a single cross-validation fold, including its metrics.
/// </summary>
/// <param name="model">Model trained during the fold; stored in <see cref="Model"/>.</param>
/// <param name="metrics">Metrics for the fold; stored in <see cref="Metrics"/>.</param>
/// <param name="scores">Scored hold-out set for the fold; stored in <see cref="ScoredHoldOutSet"/>.</param>
/// <param name="fold">Fold number; stored in <see cref="Fold"/>.</param>
internal CrossValidationResult(ITransformer model, T metrics, IDataView scores, int fold)
{
Model = model;
Metrics = metrics;
// Note the name change: the "scores" argument is exposed as ScoredHoldOutSet.
ScoredHoldOutSet = scores;
Fold = fold;
}
}

/// <summary>
/// Train the <paramref name="estimator"/> on <paramref name="numFolds"/> folds of the data sequentially.
/// Return each model and each scored test dataset.
/// </summary>
protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidateTrain(IDataView data, IEstimator<ITransformer> estimator,
protected internal CrossValidationResult[] CrossValidateTrain(IDataView data, IEstimator<ITransformer> estimator,
int numFolds, string stratificationColumn, uint? seed = null)
{
Host.CheckValue(data, nameof(data));
Expand All @@ -78,7 +159,7 @@ protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidate

EnsureStratificationColumn(ref data, ref stratificationColumn, seed);

Func<int, (IDataView scores, ITransformer model)> foldFunction =
Func<int, CrossValidationResult> foldFunction =
fold =>
{
var trainFilter = new RangeFilter(Host, new RangeFilter.Options
Expand All @@ -98,17 +179,17 @@ protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidate

var model = estimator.Fit(trainFilter);
var scoredTest = model.Transform(testFilter);
return (scoredTest, model);
return new CrossValidationResult(model, scoredTest, fold);
};

// Sequential per-fold training.
// REVIEW: we could have a parallel implementation here. We would need to
// spawn off a separate host per fold in that case.
var result = new List<(IDataView scores, ITransformer model)>();
var result = new CrossValidationResult[numFolds];
for (int fold = 0; fold < numFolds; fold++)
result.Add(foldFunction(fold));
result[fold] = foldFunction(fold);

return result.ToArray();
return result;
}

protected internal TrainCatalogBase(IHostEnvironment env, string registrationName)
Expand Down Expand Up @@ -263,13 +344,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it will use this seed.
/// If the seed is not provided either, the default value will be used.</param>
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per-fold results [](start = 21, length = 16)

maybe <see cref="CrossValidationResult{T}"/>

public (BinaryClassificationMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidateNonCalibrated(
public CrossValidationResult<BinaryClassificationMetrics>[] CrossValidateNonCalibrated(
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
string stratificationColumn = null, uint? seed = null)
{
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
return result.Select(x => (EvaluateNonCalibrated(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
return result.Select(x => new CrossValidationResult<BinaryClassificationMetrics>(x.Model,
EvaluateNonCalibrated(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
}

/// <summary>
Expand All @@ -287,13 +369,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string
/// train to the test set.</remarks>
/// <param name="seed">If <paramref name="stratificationColumn"/> is not present in the dataset, a randomly filled column is generated based on the provided <paramref name="seed"/>.</param>
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
public (CalibratedBinaryClassificationMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
public CrossValidationResult<CalibratedBinaryClassificationMetrics>[] CrossValidate(
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
string stratificationColumn = null, uint? seed = null)
{
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
return result.Select(x => new CrossValidationResult<CalibratedBinaryClassificationMetrics>(x.Model,
Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
}
}

Expand Down Expand Up @@ -369,12 +452,13 @@ public ClusteringMetrics Evaluate(IDataView data,
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it will use this seed.
/// If the seed is not provided either, the default value will be used.</param>
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
public (ClusteringMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
public CrossValidationResult<ClusteringMetrics>[] CrossValidate(
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = null, string featuresColumn = null,
string stratificationColumn = null, uint? seed = null)
{
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
return result.Select(x => (Evaluate(x.scoredTestSet, label: labelColumn, features: featuresColumn), x.model, x.scoredTestSet)).ToArray();
return result.Select(x => new CrossValidationResult<ClusteringMetrics>(x.Model,
Evaluate(x.Scores, label: labelColumn, features: featuresColumn), x.Scores, x.Fold)).ToArray();
}
}

Expand Down Expand Up @@ -444,13 +528,14 @@ public MultiClassClassifierMetrics Evaluate(IDataView data, string label = Defau
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it will use this seed.
/// If the seed is not provided either, the default value will be used.</param>
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
public (MultiClassClassifierMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
public CrossValidationResult<MultiClassClassifierMetrics>[] CrossValidate(
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
string stratificationColumn = null, uint? seed = null)
{
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
return result.Select(x => new CrossValidationResult<MultiClassClassifierMetrics>(x.Model,
Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
}
}

Expand Down Expand Up @@ -511,13 +596,14 @@ public RegressionMetrics Evaluate(IDataView data, string label = DefaultColumnNa
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it will use this seed.
/// If the seed is not provided either, the default value will be used.</param>
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
public (RegressionMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
public CrossValidationResult<RegressionMetrics>[] CrossValidate(
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
string stratificationColumn = null, uint? seed = null)
{
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
return result.Select(x => new CrossValidationResult<RegressionMetrics>(x.Model,
Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Recommender/RecommenderCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@ public RegressionMetrics Evaluate(IDataView data, string label = DefaultColumnNa
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it will use this seed.
/// If the seed is not provided either, the default value will be used.</param>
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
public (RegressionMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
public CrossValidationResult<RegressionMetrics>[] CrossValidate(
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
string stratificationColumn = null, uint? seed = null)
{
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
return result.Select(x => new CrossValidationResult<RegressionMetrics>(x.Model, Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
}
}
}
Loading