
TrainTestSplit function #1005


Merged 14 commits on Sep 26, 2018
6 changes: 6 additions & 0 deletions src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs
@@ -87,6 +87,12 @@ public void SaveTo(IHostEnvironment env, Stream outputStream)
/// </summary>
public static class CompositeDataReader
{
/// <summary>
/// Save the contents to a stream, as a "model file".
/// </summary>
public static void SaveTo<TSource>(this IDataReader<TSource> reader, IHostEnvironment env, Stream outputStream)
=> new CompositeDataReader<TSource, ITransformer>(reader).SaveTo(env, outputStream);

/// <summary>
/// Load the pipeline from stream.
/// </summary>
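For context, a minimal usage sketch of the new SaveTo extension (the env and reader variables and the file name are illustrative assumptions, not part of this diff):

    // Wraps a bare IDataReader in a CompositeDataReader with an empty
    // transformer chain, then saves it as a "model file".
    using (var stream = File.Create("reader.zip"))
        reader.SaveTo(env, stream);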
219 changes: 218 additions & 1 deletion src/Microsoft.ML.Data/Training/TrainContext.cs
@@ -2,9 +2,15 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Transforms;
using System;
using System.Collections.Generic;
using System.Linq;

-namespace Microsoft.ML.Runtime.Training
+namespace Microsoft.ML
{
/// <summary>
/// A training context is an object instantiable by a user to do various tasks relating to a particular
@@ -16,13 +22,136 @@ public abstract class TrainContextBase
protected readonly IHost Host;
internal IHostEnvironment Environment => Host;

/// <summary>
/// Split the dataset into the train set and test set according to the given fraction.
/// Respects the <paramref name="stratificationColumn"/> if provided.
/// </summary>
/// <param name="data">The dataset to split.</param>
/// <param name="testFraction">The fraction of data to go into the test set.</param>
/// <param name="stratificationColumn">Optional stratification column.</param>
/// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided),
/// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from
/// train to the test set.</remarks>
/// <returns>A pair of datasets, for the train and test set.</returns>
public (IDataView trainSet, IDataView testSet) TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null)
{
Host.CheckValue(data, nameof(data));
Host.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive");
Host.CheckValueOrNull(stratificationColumn);

EnsureStratificationColumn(ref data, ref stratificationColumn);

var trainFilter = new RangeFilter(Host, new RangeFilter.Arguments()
{
Column = stratificationColumn,
Min = 0,
Max = testFraction,
Complement = true
}, data);
var testFilter = new RangeFilter(Host, new RangeFilter.Arguments()
{
Column = stratificationColumn,
Min = 0,
Max = testFraction,
Complement = false
}, data);

return (trainFilter, testFilter);
}
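For illustration, a hedged usage sketch of the new split (the context construction and the "UserId" column are assumptions):

    var ctx = new BinaryClassificationContext(env);
    // Hold out 20% of the data; rows sharing a UserId value always land on the
    // same side of the split, so no user's examples leak from train to test.
    var (trainSet, testSet) = ctx.TrainTestSplit(data, testFraction: 0.2, stratificationColumn: "UserId");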

/// <summary>
/// Train the <paramref name="estimator"/> on <paramref name="numFolds"/> folds of the data sequentially.
/// Return each model and each scored test dataset.
/// </summary>
protected (IDataView scoredTestSet, ITransformer model)[] CrossValidateTrain(IDataView data, IEstimator<ITransformer> estimator,
int numFolds, string stratificationColumn)
{
Host.CheckValue(data, nameof(data));
Host.CheckValue(estimator, nameof(estimator));
Host.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1");
Host.CheckValueOrNull(stratificationColumn);

EnsureStratificationColumn(ref data, ref stratificationColumn);

Func<int, (IDataView scores, ITransformer model)> foldFunction =
fold =>
{
var trainFilter = new RangeFilter(Host, new RangeFilter.Arguments
{
Column = stratificationColumn,
Min = (double)fold / numFolds,
Max = (double)(fold + 1) / numFolds,
Complement = true
}, data);
var testFilter = new RangeFilter(Host, new RangeFilter.Arguments
{
Column = stratificationColumn,
Min = (double)fold / numFolds,
Max = (double)(fold + 1) / numFolds,
Complement = false
}, data);

var model = estimator.Fit(trainFilter);
var scoredTest = model.Transform(testFilter);
return (scoredTest, model);
};

// Sequential per-fold training.
// REVIEW: we could have a parallel implementation here. We would need to
// spawn off a separate host per fold in that case.
var result = new List<(IDataView scores, ITransformer model)>();
for (int fold = 0; fold < numFolds; fold++)
result.Add(foldFunction(fold));

return result.ToArray();
[Review comment, sfilipi (Member), Sep 25, 2018, on result.ToArray():]
curious, why not return a list? #Resolved

[Author reply:]
We tend to avoid returning mutable collections.
}
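To make the per-fold ranges concrete: the stratification values live in [0, 1), and fold i of k takes [i/k, (i+1)/k) as its test range, training on the complement. A standalone sketch of the same assignment (illustrative, not part of the diff):

    // Maps a stratification value in [0, 1) to the fold whose test range
    // contains it; the clamp guards against a boundary value of exactly 1.0.
    static int FoldOf(double stratValue, int numFolds)
        => Math.Min((int)(stratValue * numFolds), numFolds - 1);

    // With numFolds = 5, a value of 0.43 falls in fold 2's test range [0.4, 0.6).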

protected TrainContextBase(IHostEnvironment env, string registrationName)
{
Contracts.CheckValue(env, nameof(env));
env.CheckNonEmpty(registrationName, nameof(registrationName));
Host = env.Register(registrationName);
}

/// <summary>
/// Make sure the provided <paramref name="stratificationColumn"/> is valid
/// for <see cref="RangeFilter"/>: hash it if its type is not supported, or
/// introduce a new one if none was provided.
/// </summary>
private void EnsureStratificationColumn(ref IDataView data, ref string stratificationColumn)
{
// We need to handle two cases: if the stratification column is provided, we hash it
// down to a single value usable by the range filter. If it is not, we generate a random number.

if (stratificationColumn == null)
{
stratificationColumn = data.Schema.GetTempColumnName("StratificationColumn");
[Review comment, sfilipi (Member), Sep 25, 2018, on "StratificationColumn":]
is there a DefaultColumnName for this? #Resolved

[Author reply:]
I looked, there is none. I don't think there should be either; it's a pretty peculiar concept, unique to CV and split scenarios.
data = new GenerateNumberTransform(Host, data, stratificationColumn);
}
else
{
if (!data.Schema.TryGetColumnIndex(stratificationColumn, out int stratCol))
throw Host.ExceptSchemaMismatch(nameof(stratificationColumn), "stratification", stratificationColumn);

var type = data.Schema.GetColumnType(stratCol);
if (!RangeFilter.IsValidRangeFilterColumnType(Host, type))
{
// Hash the stratification column.
// REVIEW: this could currently crash, since Hash only accepts a limited set
// of column types. It used to be HashJoin, but we should probably extend Hash
// instead of having two hash transformations.
var origStratCol = stratificationColumn;
int tmp;
int inc = 0;

// Generate a new column with the hashed stratification column.
while (data.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc);
data = new HashEstimator(Host, origStratCol, stratificationColumn, 30).Fit(data).Transform(data);
}
}
}

/// <summary>
/// Subclasses of <see cref="TrainContext"/> will provide little "extension method" hookable objects
/// (e.g., something like <see cref="BinaryClassificationContext.Trainers"/>). User code will only
@@ -140,6 +269,50 @@ public BinaryClassifierEvaluator.Result EvaluateNonCalibrated(IDataView data, st
var eval = new BinaryClassifierEvaluator(Host, new BinaryClassifierEvaluator.Arguments() { });
return eval.Evaluate(data, label, score, predictedLabel);
}

/// <summary>
/// Run cross-validation over <paramref name="numFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
/// and respecting <paramref name="stratificationColumn"/> if provided.
/// Then evaluate each sub-model against <paramref name="labelColumn"/> and return metrics.
/// </summary>
/// <param name="data">The data to run cross-validation on.</param>
/// <param name="estimator">The estimator to fit.</param>
/// <param name="numFolds">Number of cross-validation folds.</param>
/// <param name="labelColumn">The label column (for evaluation).</param>
/// <param name="stratificationColumn">Optional stratification column.</param>
/// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided),
/// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from
/// train to the test set.</remarks>
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
public (BinaryClassifierEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidateNonCalibrated(
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null)
{
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn);
return result.Select(x => (EvaluateNonCalibrated(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
}
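A hedged usage sketch (the pipeline variable is an assumption, and Auc is the AUC property on BinaryClassifierEvaluator.Result in this era of the API):

    var ctx = new BinaryClassificationContext(env);
    var folds = ctx.CrossValidateNonCalibrated(data, pipeline, numFolds: 5);
    // Each element carries the per-fold metrics, the fitted model, and the scored test data.
    var avgAuc = folds.Average(f => f.metrics.Auc);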

/// <summary>
/// Run cross-validation over <paramref name="numFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
/// and respecting <paramref name="stratificationColumn"/> if provided.
/// Then evaluate each sub-model against <paramref name="labelColumn"/> and return metrics.
/// </summary>
/// <param name="data">The data to run cross-validation on.</param>
/// <param name="estimator">The estimator to fit.</param>
/// <param name="numFolds">Number of cross-validation folds.</param>
/// <param name="labelColumn">The label column (for evaluation).</param>
/// <param name="stratificationColumn">Optional stratification column.</param>
/// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided),
/// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from
/// train to the test set.</remarks>
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
public (BinaryClassifierEvaluator.CalibratedResult metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null)
{
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn);
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
}
}

/// <summary>
@@ -191,6 +364,28 @@ public MultiClassClassifierEvaluator.Result Evaluate(IDataView data, string labe
var eval = new MultiClassClassifierEvaluator(Host, args);
return eval.Evaluate(data, label, score, predictedLabel);
}

/// <summary>
/// Run cross-validation over <paramref name="numFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
/// and respecting <paramref name="stratificationColumn"/> if provided.
/// Then evaluate each sub-model against <paramref name="labelColumn"/> and return metrics.
/// </summary>
/// <param name="data">The data to run cross-validation on.</param>
/// <param name="estimator">The estimator to fit.</param>
/// <param name="numFolds">Number of cross-validation folds.</param>
/// <param name="labelColumn">The label column (for evaluation).</param>
/// <param name="stratificationColumn">Optional stratification column.</param>
/// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided),
/// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from
/// train to the test set.</remarks>
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
public (MultiClassClassifierEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null)
{
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn);
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
}
}

/// <summary>
@@ -233,5 +428,27 @@ public RegressionEvaluator.Result Evaluate(IDataView data, string label, string
var eval = new RegressionEvaluator(Host, new RegressionEvaluator.Arguments() { });
return eval.Evaluate(data, label, score);
}

/// <summary>
/// Run cross-validation over <paramref name="numFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
/// and respecting <paramref name="stratificationColumn"/> if provided.
/// Then evaluate each sub-model against <paramref name="labelColumn"/> and return metrics.
/// </summary>
/// <param name="data">The data to run cross-validation on.</param>
/// <param name="estimator">The estimator to fit.</param>
/// <param name="numFolds">Number of cross-validation folds.</param>
/// <param name="labelColumn">The label column (for evaluation).</param>
/// <param name="stratificationColumn">Optional stratification column.</param>
/// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided),
/// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from
/// train to the test set.</remarks>
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
public (RegressionEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null)
{
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn);
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
}
}
}
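Similarly, a hedged sketch for the regression overload (pipeline is an assumption; Rms is the root-mean-squared-error property on RegressionEvaluator.Result in this era):

    var regCtx = new RegressionContext(env);
    var folds = regCtx.CrossValidate(data, pipeline, numFolds: 5);
    var avgRms = folds.Average(f => f.metrics.Rms);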