diff --git a/docs/code/MlNetCookBook.md b/docs/code/MlNetCookBook.md
index 1551e275d3..bb5e866c04 100644
--- a/docs/code/MlNetCookBook.md
+++ b/docs/code/MlNetCookBook.md
@@ -825,12 +825,12 @@ var pipeline =
     .Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent());
 
 // Split the data 90:10 into train and test sets, train and evaluate.
-var (trainData, testData) = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1);
+var split = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1);
 
 // Train the model.
-var model = pipeline.Fit(trainData);
+var model = pipeline.Fit(split.TrainSet);
 
 // Compute quality metrics on the test set.
-var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(testData));
+var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(split.TestSet));
 Console.WriteLine(metrics.AccuracyMicro);
 
 // Now run the 5-fold cross-validation experiment, using the same pipeline.
@@ -838,7 +838,7 @@ var cvResults = mlContext.MulticlassClassification.CrossValidate(data, pipeline,
 
 // The results object is an array of 5 elements. For each of the 5 folds, we have metrics, model and scored test data.
 // Let's compute the average micro-accuracy.
-var microAccuracies = cvResults.Select(r => r.metrics.AccuracyMicro);
+var microAccuracies = cvResults.Select(r => r.Metrics.AccuracyMicro);
 Console.WriteLine(microAccuracies.Average());
 ```
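To make the cookbook change concrete: a minimal sketch of the new calling convention, reusing the `mlContext`, `data`, and `pipeline` from the snippet above. This is an illustrative aside, not part of the patch.

```csharp
// TrainTestSplit now returns a TrainTestData value with named members
// instead of a (trainSet, testSet) tuple.
var split = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1);
var model = pipeline.Fit(split.TrainSet);

// Each cross-validation fold now carries its metrics, model, scored
// hold-out set, and fold index together.
var cvResults = mlContext.MulticlassClassification.CrossValidate(data, pipeline, numFolds: 5);
foreach (var fold in cvResults)
    Console.WriteLine($"Fold {fold.Fold}: micro-accuracy = {fold.Metrics.AccuracyMicro}");
```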
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs
index eb46a9683a..ed75040e77 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs
@@ -43,7 +43,7 @@ public static void Calibration()
             var data = reader.Read(dataFile);
 
             // Split the dataset into two parts: one used for training, the other to train the calibrator
-            var (trainData, calibratorTrainingData) = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1);
+            var split = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1);
 
             // Featurize the text column through the FeaturizeText API.
             // Then append the StochasticDualCoordinateAscentBinary binary classifier, setting the "Label" column as the label of the dataset, and
@@ -56,12 +56,12 @@ public static void Calibration()
                     loss: new HingeLoss())); // By specifying loss: new HingeLoss(), StochasticDualCoordinateAscent will train a support vector machine (SVM).
 
             // Fit the pipeline, and get a transformer that knows how to score new data.
-            var transformer = pipeline.Fit(trainData);
+            var transformer = pipeline.Fit(split.TrainSet);
             IPredictor model = transformer.LastTransformer.Model;
 
             // Let's score the new data. The score will give us a numerical estimation of the chance that the particular sample
             // bears positive sentiment. This estimate is relative to the numbers obtained.
-            var scoredData = transformer.Transform(calibratorTrainingData);
+            var scoredData = transformer.Transform(split.TestSet);
             var scoredDataPreview = scoredData.Preview();
 
             PrintRowViewValues(scoredDataPreview);
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs
index 4b25bdc805..f85b68bb7e 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs
@@ -57,7 +57,7 @@ public static void LogisticRegression()
 
             IDataView data = reader.Read(dataFilePath);
 
-            var (trainData, testData) = ml.BinaryClassification.TrainTestSplit(data, testFraction: 0.2);
+            var split = ml.BinaryClassification.TrainTestSplit(data, testFraction: 0.2);
 
             var pipeline = ml.Transforms.Concatenate("Text", "workclass", "education", "marital-status",
                     "relationship", "ethnicity", "sex", "native-country")
@@ -66,9 +66,9 @@ public static void LogisticRegression()
                     "education-num", "capital-gain", "capital-loss", "hours-per-week"))
                 .Append(ml.BinaryClassification.Trainers.LogisticRegression());
 
-            var model = pipeline.Fit(trainData);
+            var model = pipeline.Fit(split.TrainSet);
 
-            var dataWithPredictions = model.Transform(testData);
+            var dataWithPredictions = model.Transform(split.TestSet);
 
             var metrics = ml.BinaryClassification.Evaluate(dataWithPredictions);
 
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs
index 341827bde3..0f06b9c905 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs
@@ -54,7 +54,7 @@ public static void SDCA_BinaryClassification()
 
             // Step 3: Run Cross-Validation on this pipeline.
             var cvResults = mlContext.BinaryClassification.CrossValidate(data, pipeline, labelColumn: "Sentiment");
-            var accuracies = cvResults.Select(r => r.metrics.Accuracy);
+            var accuracies = cvResults.Select(r => r.Metrics.Accuracy);
             Console.WriteLine(accuracies.Average());
 
             // If we wanted to specify more advanced parameters for the algorithm,
@@ -70,7 +70,7 @@ public static void SDCA_BinaryClassification()
 
             // Run Cross-Validation on this second pipeline.
             var cvResults_advancedPipeline = mlContext.BinaryClassification.CrossValidate(data, pipeline, labelColumn: "Sentiment", numFolds: 3);
-            accuracies = cvResults_advancedPipeline.Select(r => r.metrics.Accuracy);
+            accuracies = cvResults_advancedPipeline.Select(r => r.Metrics.Accuracy);
             Console.WriteLine(accuracies.Average());
         }
 
diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs
index d8b4c53265..586ac66bd0 100644
--- a/src/Microsoft.ML.Data/TrainCatalog.cs
+++ b/src/Microsoft.ML.Data/TrainCatalog.cs
@@ -25,6 +25,31 @@ public abstract class TrainCatalogBase
         [BestFriend]
         internal IHostEnvironment Environment => Host;
 
+        /// <summary>
+        /// A pair of datasets, for the train and test set.
+        /// </summary>
+        public struct TrainTestData
+        {
+            /// <summary>
+            /// Training set.
+            /// </summary>
+            public readonly IDataView TrainSet;
+            /// <summary>
+            /// Testing set.
+            /// </summary>
+            public readonly IDataView TestSet;
+            /// <summary>
+            /// Creates a pair of datasets.
+            /// </summary>
+            /// <param name="trainSet">Training set.</param>
+            /// <param name="testSet">Testing set.</param>
+            internal TrainTestData(IDataView trainSet, IDataView testSet)
+            {
+                TrainSet = trainSet;
+                TestSet = testSet;
+            }
+        }
+
         /// <summary>
         /// Split the dataset into the train set and test set according to the given fraction.
         /// Respects the <paramref name="stratificationColumn"/> if provided.
@@ -37,8 +62,7 @@ public abstract class TrainCatalogBase
         /// <param name="seed">Optional parameter used in combination with the <paramref name="stratificationColumn"/>.
         /// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value.
         /// And if it is not provided, the default value will be used.</param>
-        /// <returns>A pair of datasets, for the train and test set.</returns>
-        public (IDataView trainSet, IDataView testSet) TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null, uint? seed = null)
+        public TrainTestData TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null, uint? seed = null)
         {
             Host.CheckValue(data, nameof(data));
             Host.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive");
@@ -61,14 +85,71 @@ public abstract class TrainCatalogBase
                 Complement = false
             }, data);
 
-            return (trainFilter, testFilter);
+            return new TrainTestData(trainFilter, testFilter);
+        }
+
+        /// <summary>
+        /// Results for a specific cross-validation fold.
+        /// </summary>
+        protected internal struct CrossValidationResult
+        {
+            /// <summary>
+            /// Model trained during the cross-validation fold.
+            /// </summary>
+            public readonly ITransformer Model;
+            /// <summary>
+            /// Scored test set with <see cref="Model"/> for this fold.
+            /// </summary>
+            public readonly IDataView Scores;
+            /// <summary>
+            /// Fold number.
+            /// </summary>
+            public readonly int Fold;
+
+            public CrossValidationResult(ITransformer model, IDataView scores, int fold)
+            {
+                Model = model;
+                Scores = scores;
+                Fold = fold;
+            }
+        }
+        /// <summary>
+        /// Results of running cross-validation.
+        /// </summary>
+        /// <typeparam name="T">Type of metric class.</typeparam>
+        public sealed class CrossValidationResult<T> where T : class
+        {
+            /// <summary>
+            /// Metrics for this cross-validation fold.
+            /// </summary>
+            public readonly T Metrics;
+            /// <summary>
+            /// Model trained during the cross-validation fold.
+            /// </summary>
+            public readonly ITransformer Model;
+            /// <summary>
+            /// The scored hold-out set for this fold.
+            /// </summary>
+            public readonly IDataView ScoredHoldOutSet;
+            /// <summary>
+            /// Fold number.
+            /// </summary>
+            public readonly int Fold;
+
+            internal CrossValidationResult(ITransformer model, T metrics, IDataView scores, int fold)
+            {
+                Model = model;
+                Metrics = metrics;
+                ScoredHoldOutSet = scores;
+                Fold = fold;
+            }
         }
 
         /// <summary>
         /// Train the <paramref name="estimator"/> on <paramref name="numFolds"/> folds of the data sequentially.
        /// Return each model and each scored test dataset.
         /// </summary>
-        protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidateTrain(IDataView data, IEstimator<ITransformer> estimator,
+        protected internal CrossValidationResult[] CrossValidateTrain(IDataView data, IEstimator<ITransformer> estimator,
             int numFolds, string stratificationColumn, uint? seed = null)
         {
             Host.CheckValue(data, nameof(data));
@@ -78,7 +159,7 @@ protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidate
 
             EnsureStratificationColumn(ref data, ref stratificationColumn, seed);
 
-            Func<int, (IDataView scores, ITransformer model)> foldFunction =
+            Func<int, CrossValidationResult> foldFunction =
                 fold =>
                 {
                     var trainFilter = new RangeFilter(Host, new RangeFilter.Options
@@ -98,17 +179,17 @@ protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidate
                     var model = estimator.Fit(trainFilter);
                     var scoredTest = model.Transform(testFilter);
-                    return (scoredTest, model);
+                    return new CrossValidationResult(model, scoredTest, fold);
                 };
 
             // Sequential per-fold training.
             // REVIEW: we could have a parallel implementation here. We would need to
             // spawn off a separate host per fold in that case.
-            var result = new List<(IDataView scores, ITransformer model)>();
+            var result = new CrossValidationResult[numFolds];
             for (int fold = 0; fold < numFolds; fold++)
-                result.Add(foldFunction(fold));
+                result[fold] = foldFunction(fold);
 
-            return result.ToArray();
+            return result;
         }
 
         protected internal TrainCatalogBase(IHostEnvironment env, string registrationName)
@@ -263,13 +344,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string
         /// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value.
         /// And if it is not provided, the default value will be used.</param>
         /// <returns>Per-fold results: metrics, models, scored datasets.</returns>
-        public (BinaryClassificationMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidateNonCalibrated(
+        public CrossValidationResult<BinaryClassificationMetrics>[] CrossValidateNonCalibrated(
             IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
             string stratificationColumn = null, uint? seed = null)
         {
             Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
             var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
-            return result.Select(x => (EvaluateNonCalibrated(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
+            return result.Select(x => new CrossValidationResult<BinaryClassificationMetrics>(x.Model,
+                EvaluateNonCalibrated(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
         }
 
         /// <summary>
@@ -287,13 +369,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string
         /// train to the test set.</remarks>
         /// <param name="seed">If <paramref name="stratificationColumn"/> not present in dataset we will generate random filled column based on provided <paramref name="seed"/>.</param>
         /// <returns>Per-fold results: metrics, models, scored datasets.</returns>
-        public (CalibratedBinaryClassificationMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
+        public CrossValidationResult<CalibratedBinaryClassificationMetrics>[] CrossValidate(
             IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
             string stratificationColumn = null, uint? seed = null)
         {
             Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
             var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
-            return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
+            return result.Select(x => new CrossValidationResult<CalibratedBinaryClassificationMetrics>(x.Model,
+                Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
         }
     }
 
@@ -369,12 +452,13 @@ public ClusteringMetrics Evaluate(IDataView data,
         /// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value.
         /// And if it is not provided, the default value will be used.</param>
         /// <returns>Per-fold results: metrics, models, scored datasets.</returns>
-        public (ClusteringMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
+        public CrossValidationResult<ClusteringMetrics>[] CrossValidate(
             IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = null,
             string featuresColumn = null, string stratificationColumn = null, uint? seed = null)
         {
             var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
-            return result.Select(x => (Evaluate(x.scoredTestSet, label: labelColumn, features: featuresColumn), x.model, x.scoredTestSet)).ToArray();
+            return result.Select(x => new CrossValidationResult<ClusteringMetrics>(x.Model,
+                Evaluate(x.Scores, label: labelColumn, features: featuresColumn), x.Scores, x.Fold)).ToArray();
         }
     }
 
@@ -444,13 +528,14 @@ public MultiClassClassifierMetrics Evaluate(IDataView data, string label = Defau
         /// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value.
         /// And if it is not provided, the default value will be used.</param>
         /// <returns>Per-fold results: metrics, models, scored datasets.</returns>
-        public (MultiClassClassifierMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
+        public CrossValidationResult<MultiClassClassifierMetrics>[] CrossValidate(
             IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
             string stratificationColumn = null, uint? seed = null)
         {
             Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
             var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
-            return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
+            return result.Select(x => new CrossValidationResult<MultiClassClassifierMetrics>(x.Model,
+                Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
         }
     }
 
@@ -511,13 +596,14 @@ public RegressionMetrics Evaluate(IDataView data, string label = DefaultColumnNa
         /// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value.
         /// And if it is not provided, the default value will be used.</param>
         /// <returns>Per-fold results: metrics, models, scored datasets.</returns>
-        public (RegressionMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
+        public CrossValidationResult<RegressionMetrics>[] CrossValidate(
             IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
             string stratificationColumn = null, uint? seed = null)
         {
             Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
             var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
-            return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
+            return result.Select(x => new CrossValidationResult<RegressionMetrics>(x.Model,
+                Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
         }
     }
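With `TrainTestData` and `CrossValidationResult<T>` now defined on `TrainCatalogBase`, a hedged sketch of end-to-end consumption (illustrative only; it assumes `System.Linq` plus an `mlContext`, `data`, and regression `pipeline` like those in the samples above, and uses the existing `RegressionMetrics.RSquared` property):

```csharp
// Split via named members rather than tuple deconstruction.
var split = mlContext.Regression.TrainTestSplit(data, testFraction: 0.2);
var model = pipeline.Fit(split.TrainSet);
var metrics = mlContext.Regression.Evaluate(model.Transform(split.TestSet));

// Because each fold bundles Model, Metrics, ScoredHoldOutSet, and Fold,
// per-fold model selection becomes a one-liner.
var folds = mlContext.Regression.CrossValidate(data, pipeline, numFolds: 5);
var best = folds.OrderByDescending(f => f.Metrics.RSquared).First();
Console.WriteLine($"Best fold: {best.Fold}, R^2 = {best.Metrics.RSquared}");
ITransformer bestModel = best.Model;
```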
diff --git a/src/Microsoft.ML.Recommender/RecommenderCatalog.cs b/src/Microsoft.ML.Recommender/RecommenderCatalog.cs
index 4c98db5a74..637e4962ea 100644
--- a/src/Microsoft.ML.Recommender/RecommenderCatalog.cs
+++ b/src/Microsoft.ML.Recommender/RecommenderCatalog.cs
@@ -128,13 +128,13 @@ public RegressionMetrics Evaluate(IDataView data, string label = DefaultColumnNa
         /// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value.
         /// And if it is not provided, the default value will be used.</param>
         /// <returns>Per-fold results: metrics, models, scored datasets.</returns>
-        public (RegressionMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
+        public CrossValidationResult<RegressionMetrics>[] CrossValidate(
             IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
             string stratificationColumn = null, uint? seed = null)
         {
             Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
             var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
-            return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
+            return result.Select(x => new CrossValidationResult<RegressionMetrics>(x.Model, Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
         }
     }
 }
diff --git a/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs
index 357246a658..53992cf4ff 100644
--- a/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs
+++ b/src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs
@@ -49,8 +49,8 @@ public static (DataView<T> trainSet, DataView<T> testSet) TrainTestSplit<T>(this
                 stratName = indexer.Get(column);
             }
 
-            var (trainData, testData) = catalog.TrainTestSplit(data.AsDynamic, testFraction, stratName, seed);
-            return (new DataView<T>(env, trainData, data.Shape), new DataView<T>(env, testData, data.Shape));
+            var split = catalog.TrainTestSplit(data.AsDynamic, testFraction, stratName, seed);
+            return (new DataView<T>(env, split.TrainSet, data.Shape), new DataView<T>(env, split.TestSet, data.Shape));
         }
 
         /// <summary>
@@ -105,9 +105,9 @@ public static (RegressionMetrics metrics, Transformer<TInShape, TOutShape, TTran
 
             var results = catalog.CrossValidate(data.AsDynamic, estimator.AsDynamic, numFolds, labelName, stratName, seed);
             return results.Select(x => (
-                x.metrics,
-                new Transformer<TInShape, TOutShape, TTransformer>(env, (TTransformer)x.model, data.Shape, estimator.Shape),
-                new DataView<TOutShape>(env, x.scoredTestData, estimator.Shape)))
+                x.Metrics,
+                new Transformer<TInShape, TOutShape, TTransformer>(env, (TTransformer)x.Model, data.Shape, estimator.Shape),
+                new DataView<TOutShape>(env, x.ScoredHoldOutSet, estimator.Shape)))
                 .ToArray();
         }
 
@@ -163,9 +163,9 @@ public static (MultiClassClassifierMetrics metrics, Transformer<TInShape, TOutSh
 
             var results = catalog.CrossValidate(data.AsDynamic, estimator.AsDynamic, numFolds, labelName, stratName, seed);
             return results.Select(x => (
-                x.metrics,
-                new Transformer<TInShape, TOutShape, TTransformer>(env, (TTransformer)x.model, data.Shape, estimator.Shape),
-                new DataView<TOutShape>(env, x.scoredTestData, estimator.Shape)))
+                x.Metrics,
+                new Transformer<TInShape, TOutShape, TTransformer>(env, (TTransformer)x.Model, data.Shape, estimator.Shape),
+                new DataView<TOutShape>(env, x.ScoredHoldOutSet, estimator.Shape)))
                 .ToArray();
         }
 
@@ -221,9 +221,9 @@ public static (BinaryClassificationMetrics metrics, Transformer<TInShape, TOutSh
 
             var results = catalog.CrossValidateNonCalibrated(data.AsDynamic, estimator.AsDynamic, numFolds, labelName, stratName, seed);
             return results.Select(x => (
-                x.metrics,
-                new Transformer<TInShape, TOutShape, TTransformer>(env, (TTransformer)x.model, data.Shape, estimator.Shape),
-                new DataView<TOutShape>(env, x.scoredTestData, estimator.Shape)))
+                x.Metrics,
+                new Transformer<TInShape, TOutShape, TTransformer>(env, (TTransformer)x.Model, data.Shape, estimator.Shape),
+                new DataView<TOutShape>(env, x.ScoredHoldOutSet, estimator.Shape)))
                 .ToArray();
         }
 
@@ -279,9 +279,9 @@ public static (CalibratedBinaryClassificationMetrics metrics, Transformer<TInSha
 
             var results = catalog.CrossValidate(data.AsDynamic, estimator.AsDynamic, numFolds, labelName, stratName, seed);
             return results.Select(x => (
-                x.metrics,
-                new Transformer<TInShape, TOutShape, TTransformer>(env, (TTransformer)x.model, data.Shape, estimator.Shape),
-                new DataView<TOutShape>(env, x.scoredTestData, estimator.Shape)))
+                x.Metrics,
+                new Transformer<TInShape, TOutShape, TTransformer>(env, (TTransformer)x.Model, data.Shape, estimator.Shape),
+                new DataView<TOutShape>(env, x.ScoredHoldOutSet, estimator.Shape)))
                 .ToArray();
         }
     }
diff --git a/test/Microsoft.ML.Functional.Tests/Prediction.cs b/test/Microsoft.ML.Functional.Tests/Prediction.cs
index 7e0ff2eb44..24cc049e8f 100644
--- a/test/Microsoft.ML.Functional.Tests/Prediction.cs
+++ b/test/Microsoft.ML.Functional.Tests/Prediction.cs
@@ -24,7 +24,7 @@ public void ReconfigurablePrediction()
             // Get the dataset, create a train and test
             var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(), hasHeader: true)
                 .Read(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename));
-            (var train, var test) = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.2);
+            var split = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.2);
 
             // Create a pipeline to train on the housing data
             var pipeline = mlContext.Transforms.Concatenate("Features", new string[] {
@@ -33,9 +33,9 @@ public void ReconfigurablePrediction()
                 .Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue"))
                 .Append(mlContext.Regression.Trainers.OrdinaryLeastSquares());
 
-            var model = pipeline.Fit(train);
+            var model = pipeline.Fit(split.TrainSet);
 
-            var scoredTest = model.Transform(test);
+            var scoredTest = model.Transform(split.TestSet);
             var metrics = mlContext.Regression.Evaluate(scoredTest);
 
             Common.CheckMetrics(metrics);
diff --git a/test/Microsoft.ML.Functional.Tests/Validation.cs b/test/Microsoft.ML.Functional.Tests/Validation.cs
index b9bb617285..b04eff387a 100644
--- a/test/Microsoft.ML.Functional.Tests/Validation.cs
+++ b/test/Microsoft.ML.Functional.Tests/Validation.cs
@@ -40,14 +40,14 @@ void CrossValidation()
             var cvResult = mlContext.Regression.CrossValidate(data, pipeline, numFolds: 5);
 
             // Check that the results are valid
-            Assert.IsType<RegressionMetrics>(cvResult[0].metrics);
-            Assert.IsType<TransformerChain<RegressionPredictionTransformer<OlsLinearRegressionModelParameters>>>(cvResult[0].model);
-            Assert.True(cvResult[0].scoredTestData is IDataView);
+            Assert.IsType<RegressionMetrics>(cvResult[0].Metrics);
+            Assert.IsType<TransformerChain<RegressionPredictionTransformer<OlsLinearRegressionModelParameters>>>(cvResult[0].Model);
+            Assert.True(cvResult[0].ScoredHoldOutSet is IDataView);
             Assert.Equal(5, cvResult.Length);
 
             // And validate the metrics
             foreach (var result in cvResult)
-                Common.CheckMetrics(result.metrics);
+                Common.CheckMetrics(result.Metrics);
         }
     }
 }
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs
index a0ef508417..dd05dfd549 100644
--- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs
@@ -428,12 +428,12 @@ private void CrossValidationOn(string dataPath)
                 .Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent());
 
             // Split the data 90:10 into train and test sets, train and evaluate.
-            var (trainData, testData) = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1);
+            var split = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1);
 
             // Train the model.
-            var model = pipeline.Fit(trainData);
+            var model = pipeline.Fit(split.TrainSet);
 
             // Compute quality metrics on the test set.
-            var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(testData));
+            var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(split.TestSet));
             Console.WriteLine(metrics.AccuracyMicro);
 
             // Now run the 5-fold cross-validation experiment, using the same pipeline.
@@ -441,7 +441,7 @@ private void CrossValidationOn(string dataPath)
 
             // The results object is an array of 5 elements. For each of the 5 folds, we have metrics, model and scored test data.
             // Let's compute the average micro-accuracy.
-            var microAccuracies = cvResults.Select(r => r.metrics.AccuracyMicro);
+            var microAccuracies = cvResults.Select(r => r.Metrics.AccuracyMicro);
             Console.WriteLine(microAccuracies.Average());
         }
 
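The TestApi.cs changes below exercise the same split API. As a reminder of the semantics being tested, here is a sketch using the test's own identifiers (`input` and `getWorkclass` come from the test; not part of the patch):

```csharp
// With a stratification column, every row sharing a "Workclass" value lands
// in exactly one of the two subsets, so train and test never share a stratum.
var stratSplit = mlContext.BinaryClassification.TrainTestSplit(input, stratificationColumn: "Workclass");
var trainStrata = getWorkclass(stratSplit.TrainSet);
var testStrata = getWorkclass(stratSplit.TestSet);
```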
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs
index bc43d688fb..4b7fdbfaff 100644
--- a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs
@@ -312,22 +312,22 @@ public void TestTrainTestSplit()
             // Let's test what train test properly works with seed.
             // In order to do that, let's split same dataset, but in one case we will use default seed value,
             // and in other case we set seed to be specific value.
-            var (simpleTrain, simpleTest) = mlContext.BinaryClassification.TrainTestSplit(input);
-            var (simpleTrainWithSeed, simpleTestWithSeed) = mlContext.BinaryClassification.TrainTestSplit(input, seed: 10);
+            var simpleSplit = mlContext.BinaryClassification.TrainTestSplit(input);
+            var splitWithSeed = mlContext.BinaryClassification.TrainTestSplit(input, seed: 10);
 
             // Since test fraction is 0.1, it's much faster to compare test subsets of split.
-            var simpleTestWorkClass = getWorkclass(simpleTest);
+            var simpleTestWorkClass = getWorkclass(simpleSplit.TestSet);
 
-            var simpleWithSeedTestWorkClass = getWorkclass(simpleTestWithSeed);
+            var simpleWithSeedTestWorkClass = getWorkclass(splitWithSeed.TestSet);
 
             // Validate we get different test sets.
             Assert.NotEqual(simpleTestWorkClass, simpleWithSeedTestWorkClass);
 
             // Now let's do same thing but with presence of stratificationColumn.
             // Rows with same values in this stratificationColumn should end up in same subset (train or test).
             // So let's break dataset by "Workclass" column.
-            var (stratTrain, stratTest) = mlContext.BinaryClassification.TrainTestSplit(input, stratificationColumn: "Workclass");
-            var stratTrainWorkclass = getWorkclass(stratTrain);
-            var stratTestWorkClass = getWorkclass(stratTest);
+            var stratSplit = mlContext.BinaryClassification.TrainTestSplit(input, stratificationColumn: "Workclass");
+            var stratTrainWorkclass = getWorkclass(stratSplit.TrainSet);
+            var stratTestWorkClass = getWorkclass(stratSplit.TestSet);
             // Let's get unique values for "Workclass" column from train subset.
             var uniqueTrain = stratTrainWorkclass.GroupBy(x => x.ToString()).Select(x => x.First()).ToList();
             // and from test subset.
@@ -337,9 +337,9 @@ public void TestTrainTestSplit()
 
             // Let's do same thing, but this time we will choose different seed.
             // Stratification column should still break dataset properly without same values in both subsets.
-            var (stratWithSeedTrain, stratWithSeedTest) = mlContext.BinaryClassification.TrainTestSplit(input, stratificationColumn:"Workclass", seed: 1000000);
-            var stratTrainWithSeedWorkclass = getWorkclass(stratWithSeedTrain);
-            var stratTestWithSeedWorkClass = getWorkclass(stratWithSeedTest);
+            var stratSeed = mlContext.BinaryClassification.TrainTestSplit(input, stratificationColumn:"Workclass", seed: 1000000);
+            var stratTrainWithSeedWorkclass = getWorkclass(stratSeed.TrainSet);
+            var stratTestWithSeedWorkClass = getWorkclass(stratSeed.TestSet);
             // Let's get unique values for "Workclass" column from train subset.
             var uniqueSeedTrain = stratTrainWithSeedWorkclass.GroupBy(x => x.ToString()).Select(x => x.First()).ToList();
             // and from test subset.