From 3d9b174e3b3335cc5a126bb59a6f5c6b03418620 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Mon, 11 Feb 2019 14:05:56 -0800 Subject: [PATCH 1/3] Get rid of value tuple in TrainTest and CrossValidation --- docs/code/MlNetCookBook.md | 8 +- .../Dynamic/Calibrator.cs | 6 +- .../Dynamic/LogisticRegression.cs | 6 +- .../Microsoft.ML.Samples/Dynamic/SDCA.cs | 4 +- src/Microsoft.ML.Data/TrainCatalog.cs | 118 +++++++++++++++--- .../RecommenderCatalog.cs | 4 +- .../TrainingStaticExtensions.cs | 22 ++-- .../CookbookSamplesDynamicApi.cs | 8 +- .../Scenarios/Api/TestApi.cs | 20 +-- 9 files changed, 139 insertions(+), 57 deletions(-) diff --git a/docs/code/MlNetCookBook.md b/docs/code/MlNetCookBook.md index 1551e275d3..bb5e866c04 100644 --- a/docs/code/MlNetCookBook.md +++ b/docs/code/MlNetCookBook.md @@ -825,12 +825,12 @@ var pipeline = .Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent()); // Split the data 90:10 into train and test sets, train and evaluate. -var (trainData, testData) = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1); +var split = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1); // Train the model. -var model = pipeline.Fit(trainData); +var model = pipeline.Fit(split.TrainSet); // Compute quality metrics on the test set. -var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(testData)); +var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(split.TestSet)); Console.WriteLine(metrics.AccuracyMicro); // Now run the 5-fold cross-validation experiment, using the same pipeline. @@ -838,7 +838,7 @@ var cvResults = mlContext.MulticlassClassification.CrossValidate(data, pipeline, // The results object is an array of 5 elements. For each of the 5 folds, we have metrics, model and scored test data. // Let's compute the average micro-accuracy. -var microAccuracies = cvResults.Select(r => r.metrics.AccuracyMicro); +var microAccuracies = cvResults.Select(r => r.Metrics.AccuracyMicro); Console.WriteLine(microAccuracies.Average()); ``` diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs index eb46a9683a..ed75040e77 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs @@ -43,7 +43,7 @@ public static void Calibration() var data = reader.Read(dataFile); // Split the dataset into two parts: one used for training, the other to train the calibrator - var (trainData, calibratorTrainingData) = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1); + var split = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1); // Featurize the text column through the FeaturizeText API. // Then append the StochasticDualCoordinateAscentBinary binary classifier, setting the "Label" column as the label of the dataset, and @@ -56,12 +56,12 @@ public static void Calibration() loss: new HingeLoss())); // By specifying loss: new HingeLoss(), StochasticDualCoordinateAscent will train a support vector machine (SVM). // Fit the pipeline, and get a transformer that knows how to score new data. - var transformer = pipeline.Fit(trainData); + var transformer = pipeline.Fit(split.TrainSet); IPredictor model = transformer.LastTransformer.Model; // Let's score the new data. The score will give us a numerical estimation of the chance that the particular sample // bears positive sentiment. 
This estimate is relative to the numbers obtained. - var scoredData = transformer.Transform(calibratorTrainingData); + var scoredData = transformer.Transform(split.TestSet); var scoredDataPreview = scoredData.Preview(); PrintRowViewValues(scoredDataPreview); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs index 4b25bdc805..f85b68bb7e 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs @@ -57,7 +57,7 @@ public static void LogisticRegression() IDataView data = reader.Read(dataFilePath); - var (trainData, testData) = ml.BinaryClassification.TrainTestSplit(data, testFraction: 0.2); + var split = ml.BinaryClassification.TrainTestSplit(data, testFraction: 0.2); var pipeline = ml.Transforms.Concatenate("Text", "workclass", "education", "marital-status", "relationship", "ethnicity", "sex", "native-country") @@ -66,9 +66,9 @@ public static void LogisticRegression() "education-num", "capital-gain", "capital-loss", "hours-per-week")) .Append(ml.BinaryClassification.Trainers.LogisticRegression()); - var model = pipeline.Fit(trainData); + var model = pipeline.Fit(split.TrainSet); - var dataWithPredictions = model.Transform(testData); + var dataWithPredictions = model.Transform(split.TestSet); var metrics = ml.BinaryClassification.Evaluate(dataWithPredictions); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs index 341827bde3..0f06b9c905 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs @@ -54,7 +54,7 @@ public static void SDCA_BinaryClassification() // Step 3: Run Cross-Validation on this pipeline. var cvResults = mlContext.BinaryClassification.CrossValidate(data, pipeline, labelColumn: "Sentiment"); - var accuracies = cvResults.Select(r => r.metrics.Accuracy); + var accuracies = cvResults.Select(r => r.Metrics.Accuracy); Console.WriteLine(accuracies.Average()); // If we wanted to specify more advanced parameters for the algorithm, @@ -70,7 +70,7 @@ public static void SDCA_BinaryClassification() // Run Cross-Validation on this second pipeline. var cvResults_advancedPipeline = mlContext.BinaryClassification.CrossValidate(data, pipeline, labelColumn: "Sentiment", numFolds: 3); - accuracies = cvResults_advancedPipeline.Select(r => r.metrics.Accuracy); + accuracies = cvResults_advancedPipeline.Select(r => r.Metrics.Accuracy); Console.WriteLine(accuracies.Average()); } diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs index d8b4c53265..0dd62dfc25 100644 --- a/src/Microsoft.ML.Data/TrainCatalog.cs +++ b/src/Microsoft.ML.Data/TrainCatalog.cs @@ -25,6 +25,31 @@ public abstract class TrainCatalogBase [BestFriend] internal IHostEnvironment Environment => Host; + /// + /// A pair of datasets, for the train and test set. + /// + public struct TrainTestData + { + /// + /// Training set. + /// + public readonly IDataView TrainSet; + /// + /// Testing set. + /// + public readonly IDataView TestSet; + /// + /// Create pair of datasets. + /// + /// Training set. + /// Testing set. + public TrainTestData(IDataView trainSet, IDataView testSet) + { + TrainSet = trainSet; + TestSet = testSet; + } + } + /// /// Split the dataset into the train set and test set according to the given fraction. /// Respects the if provided. 
@@ -37,8 +62,7 @@ public abstract class TrainCatalogBase /// Optional parameter used in combination with the . /// If the is not provided, the random numbers generated to create it, will use this seed as value. /// And if it is not provided, the default value will be used. - /// A pair of datasets, for the train and test set. - public (IDataView trainSet, IDataView testSet) TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null, uint? seed = null) + public TrainTestData TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null, uint? seed = null) { Host.CheckValue(data, nameof(data)); Host.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive"); @@ -61,14 +85,68 @@ public abstract class TrainCatalogBase Complement = false }, data); - return (trainFilter, testFilter); + return new TrainTestData(trainFilter, testFilter); + } + + /// + /// Results for specific cross validation fold. + /// + protected internal struct CrossValidationResult + { + /// + /// Model trained during cross validation fold. + /// + public readonly ITransformer Model; + /// + /// for scored fold. + /// + public readonly IDataView Scores; + /// + /// Fold number. + /// + public readonly int Fold; + + public CrossValidationResult(ITransformer model, IDataView scores, int fold) + { + Model = model; + Scores = scores; + Fold = fold; + } + } + + public class CrossValidationResult where T : class + { + /// + /// Metrics for cross validation fold. + /// + public readonly T Metrics; + /// + /// Model trained during cross validation fold. + /// + public readonly ITransformer Model; + /// + /// for scored fold. + /// + public readonly IDataView Scores; + /// + /// Fold number. + /// + public readonly int Fold; + + public CrossValidationResult(ITransformer model, T metrics, IDataView scores, int fold) + { + Model = model; + Metrics = metrics; + Scores = scores; + Fold = fold; + } } /// /// Train the on folds of the data sequentially. /// Return each model and each scored test dataset. /// - protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidateTrain(IDataView data, IEstimator estimator, + protected internal CrossValidationResult[] CrossValidateTrain(IDataView data, IEstimator estimator, int numFolds, string stratificationColumn, uint? seed = null) { Host.CheckValue(data, nameof(data)); @@ -78,7 +156,7 @@ protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidate EnsureStratificationColumn(ref data, ref stratificationColumn, seed); - Func foldFunction = + Func foldFunction = fold => { var trainFilter = new RangeFilter(Host, new RangeFilter.Options @@ -98,17 +176,17 @@ protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidate var model = estimator.Fit(trainFilter); var scoredTest = model.Transform(testFilter); - return (scoredTest, model); + return new CrossValidationResult(model, scoredTest, fold); }; // Sequential per-fold training. // REVIEW: we could have a parallel implementation here. We would need to // spawn off a separate host per fold in that case. 
- var result = new List<(IDataView scores, ITransformer model)>(); + var result = new CrossValidationResult[numFolds]; for (int fold = 0; fold < numFolds; fold++) - result.Add(foldFunction(fold)); + result[fold] = foldFunction(fold); - return result.ToArray(); + return result; } protected internal TrainCatalogBase(IHostEnvironment env, string registrationName) @@ -269,7 +347,7 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); - return result.Select(x => (EvaluateNonCalibrated(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); + return result.Select(x => (EvaluateNonCalibrated(x.Scores, labelColumn), x.Model, x.Scores)).ToArray(); } /// @@ -287,13 +365,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string /// train to the test set. /// If not present in dataset we will generate random filled column based on provided . /// Per-fold results: metrics, models, scored datasets. - public (CalibratedBinaryClassificationMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( + public CrossValidationResult[] CrossValidate( IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null, uint? seed = null) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); - return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); + return result.Select(x => new CrossValidationResult(x.Model, + Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray(); } } @@ -369,12 +448,13 @@ public ClusteringMetrics Evaluate(IDataView data, /// If the is not provided, the random numbers generated to create it, will use this seed as value. /// And if it is not provided, the default value will be used. /// Per-fold results: metrics, models, scored datasets. - public (ClusteringMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( + public CrossValidationResult[] CrossValidate( IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = null, string featuresColumn = null, string stratificationColumn = null, uint? seed = null) { var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); - return result.Select(x => (Evaluate(x.scoredTestSet, label: labelColumn, features: featuresColumn), x.model, x.scoredTestSet)).ToArray(); + return result.Select(x => new CrossValidationResult(x.Model, + Evaluate(x.Scores, label: labelColumn, features: featuresColumn), x.Scores, x.Fold)).ToArray(); } } @@ -444,13 +524,14 @@ public MultiClassClassifierMetrics Evaluate(IDataView data, string label = Defau /// If the is not provided, the random numbers generated to create it, will use this seed as value. /// And if it is not provided, the default value will be used. /// Per-fold results: metrics, models, scored datasets. - public (MultiClassClassifierMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( + public CrossValidationResult[] CrossValidate( IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null, uint? 
seed = null) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); - return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); + return result.Select(x => new CrossValidationResult(x.Model, + Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray(); } } @@ -511,13 +592,14 @@ public RegressionMetrics Evaluate(IDataView data, string label = DefaultColumnNa /// If the is not provided, the random numbers generated to create it, will use this seed as value. /// And if it is not provided, the default value will be used. /// Per-fold results: metrics, models, scored datasets. - public (RegressionMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( + public CrossValidationResult[] CrossValidate( IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null, uint? seed = null) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); - return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); + return result.Select(x => new CrossValidationResult(x.Model, + Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray(); } } diff --git a/src/Microsoft.ML.Recommender/RecommenderCatalog.cs b/src/Microsoft.ML.Recommender/RecommenderCatalog.cs index 4c98db5a74..637e4962ea 100644 --- a/src/Microsoft.ML.Recommender/RecommenderCatalog.cs +++ b/src/Microsoft.ML.Recommender/RecommenderCatalog.cs @@ -128,13 +128,13 @@ public RegressionMetrics Evaluate(IDataView data, string label = DefaultColumnNa /// If the is not provided, the random numbers generated to create it, will use this seed as value. /// And if it is not provided, the default value will be used. /// Per-fold results: metrics, models, scored datasets. - public (RegressionMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate( + public CrossValidationResult[] CrossValidate( IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null, uint? 
seed = null) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); - return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray(); + return result.Select(x => new CrossValidationResult(x.Model, Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray(); } } } diff --git a/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs index 357246a658..34ab6a03ac 100644 --- a/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs @@ -49,8 +49,8 @@ public static (DataView trainSet, DataView testSet) TrainTestSplit(this stratName = indexer.Get(column); } - var (trainData, testData) = catalog.TrainTestSplit(data.AsDynamic, testFraction, stratName, seed); - return (new DataView(env, trainData, data.Shape), new DataView(env, testData, data.Shape)); + var split = catalog.TrainTestSplit(data.AsDynamic, testFraction, stratName, seed); + return (new DataView(env, split.TrainSet, data.Shape), new DataView(env, split.TestSet, data.Shape)); } /// @@ -105,9 +105,9 @@ public static (RegressionMetrics metrics, Transformer ( - x.metrics, - new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape), - new DataView(env, x.scoredTestData, estimator.Shape))) + x.Metrics, + new Transformer(env, (TTransformer)x.Model, data.Shape, estimator.Shape), + new DataView(env, x.Scores, estimator.Shape))) .ToArray(); } @@ -163,9 +163,9 @@ public static (MultiClassClassifierMetrics metrics, Transformer ( - x.metrics, - new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape), - new DataView(env, x.scoredTestData, estimator.Shape))) + x.Metrics, + new Transformer(env, (TTransformer)x.Model, data.Shape, estimator.Shape), + new DataView(env, x.Scores, estimator.Shape))) .ToArray(); } @@ -279,9 +279,9 @@ public static (CalibratedBinaryClassificationMetrics metrics, Transformer ( - x.metrics, - new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape), - new DataView(env, x.scoredTestData, estimator.Shape))) + x.Metrics, + new Transformer(env, (TTransformer)x.Model, data.Shape, estimator.Shape), + new DataView(env, x.Scores, estimator.Shape))) .ToArray(); } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs index a0ef508417..dd05dfd549 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs @@ -428,12 +428,12 @@ private void CrossValidationOn(string dataPath) .Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent()); // Split the data 90:10 into train and test sets, train and evaluate. - var (trainData, testData) = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1); + var split = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1); // Train the model. - var model = pipeline.Fit(trainData); + var model = pipeline.Fit(split.TrainSet); // Compute quality metrics on the test set. 
- var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(testData)); + var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(split.TestSet)); Console.WriteLine(metrics.AccuracyMicro); // Now run the 5-fold cross-validation experiment, using the same pipeline. @@ -441,7 +441,7 @@ private void CrossValidationOn(string dataPath) // The results object is an array of 5 elements. For each of the 5 folds, we have metrics, model and scored test data. // Let's compute the average micro-accuracy. - var microAccuracies = cvResults.Select(r => r.metrics.AccuracyMicro); + var microAccuracies = cvResults.Select(r => r.Metrics.AccuracyMicro); Console.WriteLine(microAccuracies.Average()); } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs index bc43d688fb..4b7fdbfaff 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs @@ -312,22 +312,22 @@ public void TestTrainTestSplit() // Let's test what train test properly works with seed. // In order to do that, let's split same dataset, but in one case we will use default seed value, // and in other case we set seed to be specific value. - var (simpleTrain, simpleTest) = mlContext.BinaryClassification.TrainTestSplit(input); - var (simpleTrainWithSeed, simpleTestWithSeed) = mlContext.BinaryClassification.TrainTestSplit(input, seed: 10); + var simpleSplit = mlContext.BinaryClassification.TrainTestSplit(input); + var splitWithSeed = mlContext.BinaryClassification.TrainTestSplit(input, seed: 10); // Since test fraction is 0.1, it's much faster to compare test subsets of split. - var simpleTestWorkClass = getWorkclass(simpleTest); + var simpleTestWorkClass = getWorkclass(simpleSplit.TestSet); - var simpleWithSeedTestWorkClass = getWorkclass(simpleTestWithSeed); + var simpleWithSeedTestWorkClass = getWorkclass(splitWithSeed.TestSet); // Validate we get different test sets. Assert.NotEqual(simpleTestWorkClass, simpleWithSeedTestWorkClass); // Now let's do same thing but with presence of stratificationColumn. // Rows with same values in this stratificationColumn should end up in same subset (train or test). // So let's break dataset by "Workclass" column. - var (stratTrain, stratTest) = mlContext.BinaryClassification.TrainTestSplit(input, stratificationColumn: "Workclass"); - var stratTrainWorkclass = getWorkclass(stratTrain); - var stratTestWorkClass = getWorkclass(stratTest); + var stratSplit = mlContext.BinaryClassification.TrainTestSplit(input, stratificationColumn: "Workclass"); + var stratTrainWorkclass = getWorkclass(stratSplit.TrainSet); + var stratTestWorkClass = getWorkclass(stratSplit.TestSet); // Let's get unique values for "Workclass" column from train subset. var uniqueTrain = stratTrainWorkclass.GroupBy(x => x.ToString()).Select(x => x.First()).ToList(); // and from test subset. @@ -337,9 +337,9 @@ public void TestTrainTestSplit() // Let's do same thing, but this time we will choose different seed. // Stratification column should still break dataset properly without same values in both subsets. 
- var (stratWithSeedTrain, stratWithSeedTest) = mlContext.BinaryClassification.TrainTestSplit(input, stratificationColumn:"Workclass", seed: 1000000); - var stratTrainWithSeedWorkclass = getWorkclass(stratWithSeedTrain); - var stratTestWithSeedWorkClass = getWorkclass(stratWithSeedTest); + var stratSeed = mlContext.BinaryClassification.TrainTestSplit(input, stratificationColumn:"Workclass", seed: 1000000); + var stratTrainWithSeedWorkclass = getWorkclass(stratSeed.TrainSet); + var stratTestWithSeedWorkClass = getWorkclass(stratSeed.TestSet); // Let's get unique values for "Workclass" column from train subset. var uniqueSeedTrain = stratTrainWithSeedWorkclass.GroupBy(x => x.ToString()).Select(x => x.First()).ToList(); // and from test subset. From 07900057b220641e1719dc118532ded1919df32e Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Mon, 11 Feb 2019 14:30:00 -0800 Subject: [PATCH 2/3] address small things --- src/Microsoft.ML.Data/TrainCatalog.cs | 20 +++++++++++-------- .../TrainingStaticExtensions.cs | 6 +++--- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs index 0dd62dfc25..bb9e404d68 100644 --- a/src/Microsoft.ML.Data/TrainCatalog.cs +++ b/src/Microsoft.ML.Data/TrainCatalog.cs @@ -43,7 +43,7 @@ public struct TrainTestData /// /// Training set. /// Testing set. - public TrainTestData(IDataView trainSet, IDataView testSet) + internal TrainTestData(IDataView trainSet, IDataView testSet) { TrainSet = trainSet; TestSet = testSet; @@ -98,7 +98,7 @@ protected internal struct CrossValidationResult /// public readonly ITransformer Model; /// - /// for scored fold. + /// Scored test set with for this fold. /// public readonly IDataView Scores; /// @@ -113,11 +113,14 @@ public CrossValidationResult(ITransformer model, IDataView scores, int fold) Fold = fold; } } - + /// + /// Results of running crossvalidation. + /// + /// Type of metric class public class CrossValidationResult where T : class { /// - /// Metrics for cross validation fold. + /// Metrics for this cross validation fold. /// public readonly T Metrics; /// @@ -125,7 +128,7 @@ public class CrossValidationResult where T : class /// public readonly ITransformer Model; /// - /// for scored fold. + /// Scored test set with for this fold. /// public readonly IDataView Scores; /// @@ -133,7 +136,7 @@ public class CrossValidationResult where T : class /// public readonly int Fold; - public CrossValidationResult(ITransformer model, T metrics, IDataView scores, int fold) + internal CrossValidationResult(ITransformer model, T metrics, IDataView scores, int fold) { Model = model; Metrics = metrics; @@ -341,13 +344,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string /// If the is not provided, the random numbers generated to create it, will use this seed as value. /// And if it is not provided, the default value will be used. /// Per-fold results: metrics, models, scored datasets. - public (BinaryClassificationMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidateNonCalibrated( + public CrossValidationResult[] CrossValidateNonCalibrated( IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null, uint? 
seed = null) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed); - return result.Select(x => (EvaluateNonCalibrated(x.Scores, labelColumn), x.Model, x.Scores)).ToArray(); + return result.Select(x => new CrossValidationResult(x.Model, + EvaluateNonCalibrated(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray(); } /// diff --git a/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs index 34ab6a03ac..f720524680 100644 --- a/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs @@ -221,9 +221,9 @@ public static (BinaryClassificationMetrics metrics, Transformer ( - x.metrics, - new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape), - new DataView(env, x.scoredTestData, estimator.Shape))) + x.Metrics, + new Transformer(env, (TTransformer)x.Model, data.Shape, estimator.Shape), + new DataView(env, x.Scores, estimator.Shape))) .ToArray(); } From a83b42094f59d04e0505c379844664f08afb998d Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Mon, 11 Feb 2019 15:04:42 -0800 Subject: [PATCH 3/3] make it build? --- src/Microsoft.ML.Data/TrainCatalog.cs | 18 +++++++++--------- .../TrainingStaticExtensions.cs | 8 ++++---- .../Prediction.cs | 6 +++--- .../Validation.cs | 8 ++++---- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs index bb9e404d68..586ac66bd0 100644 --- a/src/Microsoft.ML.Data/TrainCatalog.cs +++ b/src/Microsoft.ML.Data/TrainCatalog.cs @@ -89,7 +89,7 @@ public TrainTestData TrainTestSplit(IDataView data, double testFraction = 0.1, s } /// - /// Results for specific cross validation fold. + /// Results for specific cross-validation fold. /// protected internal struct CrossValidationResult { @@ -114,23 +114,23 @@ public CrossValidationResult(ITransformer model, IDataView scores, int fold) } } /// - /// Results of running crossvalidation. + /// Results of running cross-validation. /// - /// Type of metric class - public class CrossValidationResult where T : class + /// Type of metric class. + public sealed class CrossValidationResult where T : class { /// - /// Metrics for this cross validation fold. + /// Metrics for this cross-validation fold. /// public readonly T Metrics; /// - /// Model trained during cross validation fold. + /// Model trained during cross-validation fold. /// public readonly ITransformer Model; /// - /// Scored test set with for this fold. + /// The scored hold-out set for this fold. /// - public readonly IDataView Scores; + public readonly IDataView ScoredHoldOutSet; /// /// Fold number. 
/// @@ -140,7 +140,7 @@ internal CrossValidationResult(ITransformer model, T metrics, IDataView scores, { Model = model; Metrics = metrics; - Scores = scores; + ScoredHoldOutSet = scores; Fold = fold; } } diff --git a/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs index f720524680..53992cf4ff 100644 --- a/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs @@ -107,7 +107,7 @@ public static (RegressionMetrics metrics, Transformer ( x.Metrics, new Transformer(env, (TTransformer)x.Model, data.Shape, estimator.Shape), - new DataView(env, x.Scores, estimator.Shape))) + new DataView(env, x.ScoredHoldOutSet, estimator.Shape))) .ToArray(); } @@ -165,7 +165,7 @@ public static (MultiClassClassifierMetrics metrics, Transformer ( x.Metrics, new Transformer(env, (TTransformer)x.Model, data.Shape, estimator.Shape), - new DataView(env, x.Scores, estimator.Shape))) + new DataView(env, x.ScoredHoldOutSet, estimator.Shape))) .ToArray(); } @@ -223,7 +223,7 @@ public static (BinaryClassificationMetrics metrics, Transformer ( x.Metrics, new Transformer(env, (TTransformer)x.Model, data.Shape, estimator.Shape), - new DataView(env, x.Scores, estimator.Shape))) + new DataView(env, x.ScoredHoldOutSet, estimator.Shape))) .ToArray(); } @@ -281,7 +281,7 @@ public static (CalibratedBinaryClassificationMetrics metrics, Transformer ( x.Metrics, new Transformer(env, (TTransformer)x.Model, data.Shape, estimator.Shape), - new DataView(env, x.Scores, estimator.Shape))) + new DataView(env, x.ScoredHoldOutSet, estimator.Shape))) .ToArray(); } } diff --git a/test/Microsoft.ML.Functional.Tests/Prediction.cs b/test/Microsoft.ML.Functional.Tests/Prediction.cs index 7e0ff2eb44..24cc049e8f 100644 --- a/test/Microsoft.ML.Functional.Tests/Prediction.cs +++ b/test/Microsoft.ML.Functional.Tests/Prediction.cs @@ -24,7 +24,7 @@ public void ReconfigurablePrediction() // Get the dataset, create a train and test var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(), hasHeader: true) .Read(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename)); - (var train, var test) = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.2); + var split = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.2); // Create a pipeline to train on the housing data var pipeline = mlContext.Transforms.Concatenate("Features", new string[] { @@ -33,9 +33,9 @@ public void ReconfigurablePrediction() .Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue")) .Append(mlContext.Regression.Trainers.OrdinaryLeastSquares()); - var model = pipeline.Fit(train); + var model = pipeline.Fit(split.TrainSet); - var scoredTest = model.Transform(test); + var scoredTest = model.Transform(split.TestSet); var metrics = mlContext.Regression.Evaluate(scoredTest); Common.CheckMetrics(metrics); diff --git a/test/Microsoft.ML.Functional.Tests/Validation.cs b/test/Microsoft.ML.Functional.Tests/Validation.cs index b9bb617285..b04eff387a 100644 --- a/test/Microsoft.ML.Functional.Tests/Validation.cs +++ b/test/Microsoft.ML.Functional.Tests/Validation.cs @@ -40,14 +40,14 @@ void CrossValidation() var cvResult = mlContext.Regression.CrossValidate(data, pipeline, numFolds: 5); // Check that the results are valid - Assert.IsType(cvResult[0].metrics); - Assert.IsType>>(cvResult[0].model); - Assert.True(cvResult[0].scoredTestData is IDataView); + Assert.IsType(cvResult[0].Metrics); + 
Assert.IsType>>(cvResult[0].Model); + Assert.True(cvResult[0].ScoredHoldOutSet is IDataView); Assert.Equal(5, cvResult.Length); // And validate the metrics foreach (var result in cvResult) - Common.CheckMetrics(result.metrics); + Common.CheckMetrics(result.Metrics); } } }
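
---

For reviewers, a short before/after sketch of the `TrainTestSplit` change in this patch series. It assumes an existing `MLContext` named `mlContext`, a loaded `IDataView` called `data`, and an estimator `pipeline`, exactly as in the cookbook sample updated above; the snippet is illustrative and not part of the diff.

```csharp
// Before this change: TrainTestSplit returned a value tuple of IDataViews.
// var (trainData, testData) = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1);

// After: it returns a TrainTestData struct with named IDataView members.
var split = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1);

// Train on the train set and evaluate on the test set, referring to the members by name.
var model = pipeline.Fit(split.TrainSet);
var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(split.TestSet));
```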
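Similarly, a hedged sketch of the `CrossValidate` change, under the same assumed `mlContext`, `data`, and `pipeline` (and the `System.Linq` usings already present in the cookbook sample): per-fold results are now `CrossValidationResult` objects with `Metrics`, `Model`, `ScoredHoldOutSet`, and `Fold` members instead of `(metrics, model, scoredTestData)` tuples.

```csharp
// Before: per-fold results were (metrics, model, scoredTestData) value tuples.
// var microAccuracies = cvResults.Select(r => r.metrics.AccuracyMicro);

// After: per-fold results are CrossValidationResult objects with named members.
var cvResults = mlContext.MulticlassClassification.CrossValidate(data, pipeline, numFolds: 5);
var microAccuracies = cvResults.Select(r => r.Metrics.AccuracyMicro); // metrics for each fold
Console.WriteLine(microAccuracies.Average());

// The trained model and the scored hold-out data for each fold are also exposed by name:
var firstFoldModel = cvResults[0].Model;             // ITransformer trained on fold 0
var firstFoldScores = cvResults[0].ScoredHoldOutSet; // IDataView of that fold's scored hold-out rows
```

The named members are the point of the change: call sites no longer depend on tuple element order, and the scored hold-out data gets an explicit name (`ScoredHoldOutSet`, renamed from `Scores` in the third commit).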