Skip to content

Commit 1599657

Browse files
committed
Polish train catalog
1 parent 3af9a5d commit 1599657

File tree

5 files changed

+32
-32
lines changed

5 files changed

+32
-32
lines changed

src/Microsoft.ML.Data/TrainCatalog.cs

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -227,48 +227,48 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string
227227

228228
/// <summary>
229229
/// Run cross-validation over <paramref name="numberOfFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
230-
/// and respecting <paramref name="samplingKeyColumnName"/> if provided.
230+
/// and respecting <paramref name="partitionKeyColumnName"/> if provided.
231231
/// Then evaluate each sub-model against <paramref name="labelColumnName"/> and return metrics.
232232
/// </summary>
233233
/// <param name="data">The data to run cross-validation on.</param>
234234
/// <param name="estimator">The estimator to fit.</param>
235235
/// <param name="numberOfFolds">Number of cross-validation folds.</param>
236236
/// <param name="labelColumnName">The label column (for evaluation).</param>
237-
/// <param name="samplingKeyColumnName">Name of a column to use for grouping rows. If two examples share the same value of the <paramref name="samplingKeyColumnName"/>,
237+
/// <param name="partitionKeyColumnName">Name of a column to use for grouping rows. If two examples share the same value of the <paramref name="partitionKeyColumnName"/>,
238238
/// they are guaranteed to appear in the same subset (train or test). This can be used to ensure no label leakage from the train to the test set.
239239
/// If <see langword="null"/> no row grouping will be performed.</param>
240240
/// <param name="seed">Seed for the random number generator used to select rows for cross-validation folds.</param>
241241
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
242242
public CrossValidationResult<BinaryClassificationMetrics>[] CrossValidateNonCalibrated(
243243
IDataView data, IEstimator<ITransformer> estimator, int numberOfFolds = 5, string labelColumnName = DefaultColumnNames.Label,
244-
string samplingKeyColumnName = null, int? seed = null)
244+
string partitionKeyColumnName = null, int? seed = null)
245245
{
246246
Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
247-
var result = CrossValidateTrain(data, estimator, numberOfFolds, samplingKeyColumnName, seed);
247+
var result = CrossValidateTrain(data, estimator, numberOfFolds, partitionKeyColumnName, seed);
248248
return result.Select(x => new CrossValidationResult<BinaryClassificationMetrics>(x.Model,
249249
EvaluateNonCalibrated(x.Scores, labelColumnName), x.Scores, x.Fold)).ToArray();
250250
}
251251

252252
/// <summary>
253253
/// Run cross-validation over <paramref name="numberOfFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
254-
/// and respecting <paramref name="samplingKeyColumnName"/> if provided.
254+
/// and respecting <paramref name="partitionKeyColumnName"/> if provided.
255255
/// Then evaluate each sub-model against <paramref name="labelColumnName"/> and return metrics.
256256
/// </summary>
257257
/// <param name="data">The data to run cross-validation on.</param>
258258
/// <param name="estimator">The estimator to fit.</param>
259259
/// <param name="numberOfFolds">Number of cross-validation folds.</param>
260260
/// <param name="labelColumnName">The label column (for evaluation).</param>
261-
/// <param name="samplingKeyColumnName">Name of a column to use for grouping rows. If two examples share the same value of the <paramref name="samplingKeyColumnName"/>,
261+
/// <param name="partitionKeyColumnName">Name of a column to use for grouping rows. If two examples share the same value of the <paramref name="partitionKeyColumnName"/>,
262262
/// they are guaranteed to appear in the same subset (train or test). This can be used to ensure no label leakage from the train to the test set.
263263
/// If <see langword="null"/> no row grouping will be performed.</param>
264264
/// <param name="seed">Seed for the random number generator used to select rows for cross-validation folds.</param>
265265
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
266266
public CrossValidationResult<CalibratedBinaryClassificationMetrics>[] CrossValidate(
267267
IDataView data, IEstimator<ITransformer> estimator, int numberOfFolds = 5, string labelColumnName = DefaultColumnNames.Label,
268-
string samplingKeyColumnName = null, int? seed = null)
268+
string partitionKeyColumnName = null, int? seed = null)
269269
{
270270
Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
271-
var result = CrossValidateTrain(data, estimator, numberOfFolds, samplingKeyColumnName, seed);
271+
var result = CrossValidateTrain(data, estimator, numberOfFolds, partitionKeyColumnName, seed);
272272
return result.Select(x => new CrossValidationResult<CalibratedBinaryClassificationMetrics>(x.Model,
273273
Evaluate(x.Scores, labelColumnName), x.Scores, x.Fold)).ToArray();
274274
}
@@ -431,23 +431,23 @@ public ClusteringMetrics Evaluate(IDataView data,
431431

432432
/// <summary>
433433
/// Run cross-validation over <paramref name="numberOfFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
434-
/// and respecting <paramref name="samplingKeyColumnName"/> if provided.
434+
/// and respecting <paramref name="partitionKeyColumnName"/> if provided.
435435
/// Then evaluate each sub-model against <paramref name="labelColumnName"/> and return metrics.
436436
/// </summary>
437437
/// <param name="data">The data to run cross-validation on.</param>
438438
/// <param name="estimator">The estimator to fit.</param>
439439
/// <param name="numberOfFolds">Number of cross-validation folds.</param>
440440
/// <param name="labelColumnName">Optional label column for evaluation (clustering tasks may not always have a label).</param>
441441
/// <param name="featuresColumnName">Optional features column for evaluation (needed for calculating Dbi metric)</param>
442-
/// <param name="samplingKeyColumnName">Name of a column to use for grouping rows. If two examples share the same value of the <paramref name="samplingKeyColumnName"/>,
442+
/// <param name="partitionKeyColumnName">Name of a column to use for grouping rows. If two examples share the same value of the <paramref name="partitionKeyColumnName"/>,
443443
/// they are guaranteed to appear in the same subset (train or test). This can be used to ensure no label leakage from the train to the test set.
444444
/// If <see langword="null"/> no row grouping will be performed.</param>
445445
/// <param name="seed">Seed for the random number generator used to select rows for cross-validation folds.</param>
446446
public CrossValidationResult<ClusteringMetrics>[] CrossValidate(
447447
IDataView data, IEstimator<ITransformer> estimator, int numberOfFolds = 5, string labelColumnName = null, string featuresColumnName = null,
448-
string samplingKeyColumnName = null, int? seed = null)
448+
string partitionKeyColumnName = null, int? seed = null)
449449
{
450-
var result = CrossValidateTrain(data, estimator, numberOfFolds, samplingKeyColumnName, seed);
450+
var result = CrossValidateTrain(data, estimator, numberOfFolds, partitionKeyColumnName, seed);
451451
return result.Select(x => new CrossValidationResult<ClusteringMetrics>(x.Model,
452452
Evaluate(x.Scores, labelColumnName: labelColumnName, featureColumnName: featuresColumnName), x.Scores, x.Fold)).ToArray();
453453
}
@@ -484,46 +484,46 @@ internal MulticlassClassificationTrainers(MulticlassClassificationCatalog catalo
484484
/// <param name="labelColumnName">The name of the label column in <paramref name="data"/>.</param>
485485
/// <param name="scoreColumnName">The name of the score column in <paramref name="data"/>.</param>
486486
/// <param name="predictedLabelColumnName">The name of the predicted label column in <paramref name="data"/>.</param>
487-
/// <param name="topK">If given a positive value, the <see cref="MulticlassClassificationMetrics.TopKAccuracy"/> will be filled with
487+
/// <param name="topPredictionCount">If given a positive value, the <see cref="MulticlassClassificationMetrics.TopKAccuracy"/> will be filled with
488488
/// the top-K accuracy, that is, the accuracy assuming we consider an example with the correct class within
489489
/// the top-K values as being stored "correctly."</param>
490490
/// <returns>The evaluation results for these calibrated outputs.</returns>
491491
public MulticlassClassificationMetrics Evaluate(IDataView data, string labelColumnName = DefaultColumnNames.Label, string scoreColumnName = DefaultColumnNames.Score,
492-
string predictedLabelColumnName = DefaultColumnNames.PredictedLabel, int topK = 0)
492+
string predictedLabelColumnName = DefaultColumnNames.PredictedLabel, int topPredictionCount = 0)
493493
{
494494
Environment.CheckValue(data, nameof(data));
495495
Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
496496
Environment.CheckNonEmpty(scoreColumnName, nameof(scoreColumnName));
497497
Environment.CheckNonEmpty(predictedLabelColumnName, nameof(predictedLabelColumnName));
498498

499499
var args = new MulticlassClassificationEvaluator.Arguments() { };
500-
if (topK > 0)
501-
args.OutputTopKAcc = topK;
500+
if (topPredictionCount > 0)
501+
args.OutputTopKAcc = topPredictionCount;
502502
var eval = new MulticlassClassificationEvaluator(Environment, args);
503503
return eval.Evaluate(data, labelColumnName, scoreColumnName, predictedLabelColumnName);
504504
}
505505

506506
/// <summary>
507507
/// Run cross-validation over <paramref name="numberOfFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
508-
/// and respecting <paramref name="samplingKeyColumnName"/> if provided.
508+
/// and respecting <paramref name="partitionKeyColumnName"/> if provided.
509509
/// Then evaluate each sub-model against <paramref name="labelColumnName"/> and return metrics.
510510
/// </summary>
511511
/// <param name="data">The data to run cross-validation on.</param>
512512
/// <param name="estimator">The estimator to fit.</param>
513513
/// <param name="numberOfFolds">Number of cross-validation folds.</param>
514514
/// <param name="labelColumnName">The label column (for evaluation).</param>
515-
/// <param name="samplingKeyColumnName">Name of a column to use for grouping rows. If two examples share the same value of the <paramref name="samplingKeyColumnName"/>,
515+
/// <param name="partitionKeyColumnName">Name of a column to use for grouping rows. If two examples share the same value of the <paramref name="partitionKeyColumnName"/>,
516516
/// they are guaranteed to appear in the same subset (train or test). This can be used to ensure no label leakage from the train to the test set.
517517
/// If <see langword="null"/> no row grouping will be performed.</param>
518518
/// <param name="seed">Seed for the random number generator used to select rows for cross-validation folds.</param>
519519
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
520520
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
521521
public CrossValidationResult<MulticlassClassificationMetrics>[] CrossValidate(
522522
IDataView data, IEstimator<ITransformer> estimator, int numberOfFolds = 5, string labelColumnName = DefaultColumnNames.Label,
523-
string samplingKeyColumnName = null, int? seed = null)
523+
string partitionKeyColumnName = null, int? seed = null)
524524
{
525525
Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
526-
var result = CrossValidateTrain(data, estimator, numberOfFolds, samplingKeyColumnName, seed);
526+
var result = CrossValidateTrain(data, estimator, numberOfFolds, partitionKeyColumnName, seed);
527527
return result.Select(x => new CrossValidationResult<MulticlassClassificationMetrics>(x.Model,
528528
Evaluate(x.Scores, labelColumnName), x.Scores, x.Fold)).ToArray();
529529
}
@@ -572,24 +572,24 @@ public RegressionMetrics Evaluate(IDataView data, string labelColumnName = Defau
572572

573573
/// <summary>
574574
/// Run cross-validation over <paramref name="numberOfFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
575-
/// and respecting <paramref name="samplingKeyColumnName"/> if provided.
575+
/// and respecting <paramref name="partitionKeyColumnName"/> if provided.
576576
/// Then evaluate each sub-model against <paramref name="labelColumnName"/> and return metrics.
577577
/// </summary>
578578
/// <param name="data">The data to run cross-validation on.</param>
579579
/// <param name="estimator">The estimator to fit.</param>
580580
/// <param name="numberOfFolds">Number of cross-validation folds.</param>
581581
/// <param name="labelColumnName">The label column (for evaluation).</param>
582-
/// <param name="samplingKeyColumnName">Name of a column to use for grouping rows. If two examples share the same value of the <paramref name="samplingKeyColumnName"/>,
582+
/// <param name="partitionKeyColumnName">Name of a column to use for grouping rows. If two examples share the same value of the <paramref name="partitionKeyColumnName"/>,
583583
/// they are guaranteed to appear in the same subset (train or test). This can be used to ensure no label leakage from the train to the test set.
584584
/// If <see langword="null"/> no row grouping will be performed.</param>
585585
/// <param name="seed">Seed for the random number generator used to select rows for cross-validation folds.</param>
586586
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
587587
public CrossValidationResult<RegressionMetrics>[] CrossValidate(
588588
IDataView data, IEstimator<ITransformer> estimator, int numberOfFolds = 5, string labelColumnName = DefaultColumnNames.Label,
589-
string samplingKeyColumnName = null, int? seed = null)
589+
string partitionKeyColumnName = null, int? seed = null)
590590
{
591591
Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
592-
var result = CrossValidateTrain(data, estimator, numberOfFolds, samplingKeyColumnName, seed);
592+
var result = CrossValidateTrain(data, estimator, numberOfFolds, partitionKeyColumnName, seed);
593593
return result.Select(x => new CrossValidationResult<RegressionMetrics>(x.Model,
594594
Evaluate(x.Scores, labelColumnName), x.Scores, x.Fold)).ToArray();
595595
}
@@ -673,18 +673,18 @@ internal AnomalyDetectionTrainers(AnomalyDetectionCatalog catalog)
673673
/// <param name="labelColumnName">The name of the label column in <paramref name="data"/>.</param>
674674
/// <param name="scoreColumnName">The name of the score column in <paramref name="data"/>.</param>
675675
/// <param name="predictedLabelColumnName">The name of the predicted label column in <paramref name="data"/>.</param>
676-
/// <param name="k">The number of false positives to compute the <see cref="AnomalyDetectionMetrics.DetectionRateAtKFalsePositives"/> metric. </param>
676+
/// <param name="falsePositiveCount">The number of false positives to compute the <see cref="AnomalyDetectionMetrics.DetectionRateAtKFalsePositives"/> metric. </param>
677677
/// <returns>Evaluation results.</returns>
678678
public AnomalyDetectionMetrics Evaluate(IDataView data, string labelColumnName = DefaultColumnNames.Label, string scoreColumnName = DefaultColumnNames.Score,
679-
string predictedLabelColumnName = DefaultColumnNames.PredictedLabel, int k = 10)
679+
string predictedLabelColumnName = DefaultColumnNames.PredictedLabel, int falsePositiveCount = 10)
680680
{
681681
Environment.CheckValue(data, nameof(data));
682682
Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
683683
Environment.CheckNonEmpty(scoreColumnName, nameof(scoreColumnName));
684684
Environment.CheckNonEmpty(predictedLabelColumnName, nameof(predictedLabelColumnName));
685685

686686
var args = new AnomalyDetectionEvaluator.Arguments();
687-
args.K = k;
687+
args.K = falsePositiveCount;
688688

689689
var eval = new AnomalyDetectionEvaluator(Environment, args);
690690
return eval.Evaluate(data, labelColumnName, scoreColumnName, predictedLabelColumnName);

test/Microsoft.ML.Tests/AnomalyDetectionTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public void RandomizedPcaTrainerBaselineTest()
3030
var transformedData = DetectAnomalyInMnistOneClass(trainPath, testPath);
3131

3232
// Evaluate
33-
var metrics = ML.AnomalyDetection.Evaluate(transformedData, k: 5);
33+
var metrics = ML.AnomalyDetection.Evaluate(transformedData, falsePositiveCount: 5);
3434

3535
Assert.Equal(0.98667, metrics.AreaUnderRocCurve, 5);
3636
Assert.Equal(0.90000, metrics.DetectionRateAtKFalsePositives, 5);

test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ public void TrainAndPredictIrisModelTest()
8484

8585
// Evaluate the trained pipeline
8686
var predicted = trainedModel.Transform(testData);
87-
var metrics = mlContext.MulticlassClassification.Evaluate(predicted, topK: 3);
87+
var metrics = mlContext.MulticlassClassification.Evaluate(predicted, topPredictionCount: 3);
8888

8989
Assert.Equal(.98, metrics.MacroAccuracy);
9090
Assert.Equal(.98, metrics.MicroAccuracy, 2);

test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public void TrainAndPredictIrisModelWithStringLabelTest()
8787

8888
// Evaluate the trained pipeline
8989
var predicted = trainedModel.Transform(testData);
90-
var metrics = mlContext.MulticlassClassification.Evaluate(predicted, topK: 3);
90+
var metrics = mlContext.MulticlassClassification.Evaluate(predicted, topPredictionCount: 3);
9191

9292
Assert.Equal(.98, metrics.MacroAccuracy);
9393
Assert.Equal(.98, metrics.MicroAccuracy, 2);

test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ public void SdcaMulticlassLogisticRegression()
158158

159159
// Step 4: Make prediction and evaluate its quality (on training set).
160160
var prediction = model.Transform(data);
161-
var metrics = mlContext.MulticlassClassification.Evaluate(prediction, labelColumnName: "LabelIndex", topK: 1);
161+
var metrics = mlContext.MulticlassClassification.Evaluate(prediction, labelColumnName: "LabelIndex", topPredictionCount: 1);
162162

163163
// Check a few metrics to make sure the trained model is ok.
164164
Assert.InRange(metrics.TopKAccuracy, 0.8, 1);
@@ -192,7 +192,7 @@ public void SdcaMulticlassSupportVectorMachine()
192192

193193
// Step 4: Make prediction and evaluate its quality (on training set).
194194
var prediction = model.Transform(data);
195-
var metrics = mlContext.MulticlassClassification.Evaluate(prediction, labelColumnName: "LabelIndex", topK: 1);
195+
var metrics = mlContext.MulticlassClassification.Evaluate(prediction, labelColumnName: "LabelIndex", topPredictionCount: 1);
196196

197197
// Check a few metrics to make sure the trained model is ok.
198198
Assert.InRange(metrics.TopKAccuracy, 0.8, 1);

0 commit comments

Comments
 (0)