Skip to content

Commit f00b206

Browse files
committed
Update calibrator estimators to be more suitable.
* Internalize infrastructure only interface ICalibratorTrainer. * Update calibrator estimators so they are a suitable replacement for calibrator trainers in the public surface, e.g., no longer take IPredictor.
1 parent 5c442a9 commit f00b206

File tree

13 files changed

+276
-229
lines changed

13 files changed

+276
-229
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
using System;
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
26
using System.Linq;
37
using Microsoft.ML.Calibrator;
48
using Microsoft.ML.Data;
@@ -75,7 +79,7 @@ public static void Calibration()
7579

7680
// Let's train a calibrator estimator on this scored dataset. The trained calibrator estimator produces a transformer
7781
// that can transform the scored data by adding a new column names "Probability".
78-
var calibratorEstimator = new PlattCalibratorEstimator(mlContext, model, "Sentiment", "Features");
82+
var calibratorEstimator = new PlattCalibratorEstimator(mlContext, "Sentiment", "Score");
7983
var calibratorTransformer = calibratorEstimator.Fit(scoredData);
8084

8185
// Transform the scored data with a calibrator transfomer by adding a new column names "Probability".

src/Microsoft.ML.Data/Prediction/Calibrator.cs

+96-57
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
using Microsoft.ML;
1313
using Microsoft.ML.Calibrator;
1414
using Microsoft.ML.CommandLine;
15+
using Microsoft.ML.Core.Data;
1516
using Microsoft.ML.Data;
1617
using Microsoft.ML.EntryPoints;
1718
using Microsoft.ML.Internal.Calibration;
@@ -85,14 +86,24 @@ namespace Microsoft.ML.Internal.Calibration
8586
/// <summary>
8687
/// Signature for the loaders of calibrators.
8788
/// </summary>
88-
public delegate void SignatureCalibrator();
89+
[BestFriend]
90+
internal delegate void SignatureCalibrator();
8991

92+
[BestFriend]
9093
[TlcModule.ComponentKind("CalibratorTrainer")]
91-
public interface ICalibratorTrainerFactory : IComponentFactory<ICalibratorTrainer>
94+
internal interface ICalibratorTrainerFactory : IComponentFactory<ICalibratorTrainer>
9295
{
9396
}
9497

95-
public interface ICalibratorTrainer
98+
/// <summary>
99+
/// This is a legacy interface still used for the command line and entry-points. All applications should transition away
100+
/// from this interface and still work instead via <see cref="IEstimator{TTransformer}"/> of <see cref="CalibratorTransformer{TICalibrator}"/>,
101+
/// for example, the subclasses of <see cref="CalibratorEstimatorBase{TICalibrator}"/>. However for now we retain this
102+
/// until such time as those components making use of it can transition to the new way. No public surface should use
103+
/// this, and even new internal code should avoid its use if possible.
104+
/// </summary>
105+
[BestFriend]
106+
internal interface ICalibratorTrainer
96107
{
97108
/// <summary>
98109
/// True if the calibrator needs training, false otherwise.
@@ -107,6 +118,17 @@ public interface ICalibratorTrainer
107118
ICalibrator FinishTraining(IChannel ch);
108119
}
109120

121+
/// <summary>
122+
/// This is a shim interface implemented only by <see cref="CalibratorEstimatorBase{TICalibrator}"/> to enable
123+
/// access to the underlying legacy <see cref="ICalibratorTrainer"/> interface for those components that use
124+
/// that old mechanism that we do not care to change right now.
125+
/// </summary>
126+
[BestFriend]
127+
internal interface IHaveCalibratorTrainer
128+
{
129+
ICalibratorTrainer CalibratorTrainer { get; }
130+
}
131+
110132
/// <summary>
111133
/// An interface for predictors that take care of their own calibration given an input data view.
112134
/// </summary>
@@ -842,6 +864,64 @@ public static IPredictor GetCalibratedPredictor(IHostEnvironment env, IChannel c
842864
return CreateCalibratedPredictor(env, (IPredictorProducing<float>)predictor, trainedCalibrator);
843865
}
844866

867+
public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, IDataView scored, string labelColumn, string scoreColumn, string weightColumn = null, int maxRows = _maxCalibrationExamples)
868+
{
869+
Contracts.CheckValue(env, nameof(env));
870+
env.CheckValue(ch, nameof(ch));
871+
ch.CheckValue(scored, nameof(scored));
872+
ch.CheckValue(caliTrainer, nameof(caliTrainer));
873+
ch.CheckParam(!caliTrainer.NeedsTraining || !string.IsNullOrWhiteSpace(labelColumn), nameof(labelColumn),
874+
"If " + nameof(caliTrainer) + " requires training, then " + nameof(labelColumn) + " must have a value.");
875+
ch.CheckNonWhiteSpace(scoreColumn, nameof(scoreColumn));
876+
877+
if (!caliTrainer.NeedsTraining)
878+
return caliTrainer.FinishTraining(ch);
879+
880+
var labelCol = scored.Schema[labelColumn];
881+
var scoreCol = scored.Schema[scoreColumn];
882+
883+
var weightCol = weightColumn == null ? null : scored.Schema.GetColumnOrNull(weightColumn);
884+
if (weightColumn != null && !weightCol.HasValue)
885+
throw ch.ExceptSchemaMismatch(nameof(weightColumn), "weight", weightColumn);
886+
887+
ch.Info("Training calibrator.");
888+
889+
var cols = weightCol.HasValue ?
890+
new Schema.Column[] { labelCol, scoreCol, weightCol.Value } :
891+
new Schema.Column[] { labelCol, scoreCol };
892+
893+
using (var cursor = scored.GetRowCursor(cols))
894+
{
895+
var labelGetter = RowCursorUtils.GetLabelGetter(cursor, labelCol.Index);
896+
var scoreGetter = RowCursorUtils.GetGetterAs<Single>(NumberType.R4, cursor, scoreCol.Index);
897+
ValueGetter<Single> weightGetter = !weightCol.HasValue ? (ref float dst) => dst = 1 :
898+
RowCursorUtils.GetGetterAs<Single>(NumberType.R4, cursor, weightCol.Value.Index);
899+
900+
int num = 0;
901+
while (cursor.MoveNext())
902+
{
903+
Single label = 0;
904+
labelGetter(ref label);
905+
if (!FloatUtils.IsFinite(label))
906+
continue;
907+
Single score = 0;
908+
scoreGetter(ref score);
909+
if (!FloatUtils.IsFinite(score))
910+
continue;
911+
Single weight = 0;
912+
weightGetter(ref weight);
913+
if (!FloatUtils.IsFinite(weight))
914+
continue;
915+
916+
caliTrainer.ProcessTrainingExample(score, label > 0, weight);
917+
918+
if (maxRows > 0 && ++num >= maxRows)
919+
break;
920+
}
921+
}
922+
return caliTrainer.FinishTraining(ch);
923+
}
924+
845925
/// <summary>
846926
/// Trains a calibrator.
847927
/// </summary>
@@ -857,60 +937,14 @@ public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICa
857937
{
858938
Contracts.CheckValue(env, nameof(env));
859939
env.CheckValue(ch, nameof(ch));
940+
ch.CheckValue(caliTrainer, nameof(caliTrainer));
860941
ch.CheckValue(predictor, nameof(predictor));
861942
ch.CheckValue(data, nameof(data));
862943
ch.CheckParam(data.Schema.Label.HasValue, nameof(data), "data must have a Label column");
863944

864945
var scored = ScoreUtils.GetScorer(predictor, data, env, null);
865-
866-
if (caliTrainer.NeedsTraining)
867-
{
868-
int labelCol;
869-
if (!scored.Schema.TryGetColumnIndex(data.Schema.Label.Value.Name, out labelCol))
870-
throw ch.Except("No label column found");
871-
int scoreCol;
872-
if (!scored.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreCol))
873-
throw ch.Except("No score column found");
874-
int weightCol = -1;
875-
if (data.Schema.Weight?.Name is string weightName && scored.Schema.GetColumnOrNull(weightName)?.Index is int weightIdx)
876-
weightCol = weightIdx;
877-
ch.Info("Training calibrator.");
878-
879-
var cols = weightCol > -1 ?
880-
new Schema.Column[] { scored.Schema[labelCol], scored.Schema[scoreCol], scored.Schema[weightCol] } :
881-
new Schema.Column[] { scored.Schema[labelCol], scored.Schema[scoreCol] };
882-
883-
using (var cursor = scored.GetRowCursor(cols))
884-
{
885-
var labelGetter = RowCursorUtils.GetLabelGetter(cursor, labelCol);
886-
var scoreGetter = RowCursorUtils.GetGetterAs<Single>(NumberType.R4, cursor, scoreCol);
887-
ValueGetter<Single> weightGetter = weightCol == -1 ? (ref float dst) => dst = 1 :
888-
RowCursorUtils.GetGetterAs<Single>(NumberType.R4, cursor, weightCol);
889-
890-
int num = 0;
891-
while (cursor.MoveNext())
892-
{
893-
Single label = 0;
894-
labelGetter(ref label);
895-
if (!FloatUtils.IsFinite(label))
896-
continue;
897-
Single score = 0;
898-
scoreGetter(ref score);
899-
if (!FloatUtils.IsFinite(score))
900-
continue;
901-
Single weight = 0;
902-
weightGetter(ref weight);
903-
if (!FloatUtils.IsFinite(weight))
904-
continue;
905-
906-
caliTrainer.ProcessTrainingExample(score, label > 0, weight);
907-
908-
if (maxRows > 0 && ++num >= maxRows)
909-
break;
910-
}
911-
}
912-
}
913-
return caliTrainer.FinishTraining(ch);
946+
var scoreColumn = scored.Schema[DefaultColumnNames.Score];
947+
return TrainCalibrator(env, ch, caliTrainer, scored, data.Schema.Label.Value.Name, DefaultColumnNames.Score, data.Schema.Weight?.Name, maxRows);
914948
}
915949

916950
public static IPredictorProducing<float> CreateCalibratedPredictor<TSubPredictor, TCalibrator>(IHostEnvironment env, TSubPredictor predictor, TCalibrator cali)
@@ -953,7 +987,8 @@ public ICalibratorTrainer CreateComponent(IHostEnvironment env)
953987
/// The probability of belonging to a particular class, for example class 1, is the number of class 1 instances in the bin, divided by the total number
954988
/// of instances in that bin.
955989
/// </summary>
956-
public sealed class NaiveCalibratorTrainer : ICalibratorTrainer
990+
[BestFriend]
991+
internal sealed class NaiveCalibratorTrainer : ICalibratorTrainer
957992
{
958993
private readonly IHost _host;
959994

@@ -1181,7 +1216,8 @@ public string GetSummary()
11811216
/// <summary>
11821217
/// Base class for calibrator trainers.
11831218
/// </summary>
1184-
public abstract class CalibratorTrainerBase : ICalibratorTrainer
1219+
[BestFriend]
1220+
internal abstract class CalibratorTrainerBase : ICalibratorTrainer
11851221
{
11861222
protected readonly IHost Host;
11871223
protected CalibrationDataStore Data;
@@ -1230,7 +1266,8 @@ public ICalibratorTrainer CreateComponent(IHostEnvironment env)
12301266
}
12311267
}
12321268

1233-
public sealed class PlattCalibratorTrainer : CalibratorTrainerBase
1269+
[BestFriend]
1270+
internal sealed class PlattCalibratorTrainer : CalibratorTrainerBase
12341271
{
12351272
internal const string UserName = "Sigmoid Calibration";
12361273
internal const string LoadName = "PlattCalibration";
@@ -1389,7 +1426,8 @@ public override ICalibrator CreateCalibrator(IChannel ch)
13891426
}
13901427
}
13911428

1392-
public sealed class FixedPlattCalibratorTrainer : ICalibratorTrainer
1429+
[BestFriend]
1430+
internal sealed class FixedPlattCalibratorTrainer : ICalibratorTrainer
13931431
{
13941432
[TlcModule.Component(Name = "FixedPlattCalibrator", FriendlyName = "Fixed Platt Calibrator", Aliases = new[] { "FixedPlatt", "FixedSigmoid" })]
13951433
public sealed class Arguments : ICalibratorTrainerFactory
@@ -1591,7 +1629,8 @@ public ICalibratorTrainer CreateComponent(IHostEnvironment env)
15911629
}
15921630
}
15931631

1594-
public class PavCalibratorTrainer : CalibratorTrainerBase
1632+
[BestFriend]
1633+
internal sealed class PavCalibratorTrainer : CalibratorTrainerBase
15951634
{
15961635
// a piece of the piecwise function
15971636
private readonly struct Piece

0 commit comments

Comments
 (0)