12
12
using Microsoft . ML ;
13
13
using Microsoft . ML . Calibrator ;
14
14
using Microsoft . ML . CommandLine ;
15
+ using Microsoft . ML . Core . Data ;
15
16
using Microsoft . ML . Data ;
16
17
using Microsoft . ML . EntryPoints ;
17
18
using Microsoft . ML . Internal . Calibration ;
@@ -85,14 +86,24 @@ namespace Microsoft.ML.Internal.Calibration
85
86
/// <summary>
86
87
/// Signature for the loaders of calibrators.
87
88
/// </summary>
88
- public delegate void SignatureCalibrator ( ) ;
89
+ [ BestFriend ]
90
+ internal delegate void SignatureCalibrator ( ) ;
89
91
92
+ [ BestFriend ]
90
93
[ TlcModule . ComponentKind ( "CalibratorTrainer" ) ]
91
- public interface ICalibratorTrainerFactory : IComponentFactory < ICalibratorTrainer >
94
+ internal interface ICalibratorTrainerFactory : IComponentFactory < ICalibratorTrainer >
92
95
{
93
96
}
94
97
95
- public interface ICalibratorTrainer
98
+ /// <summary>
99
+ /// This is a legacy interface still used for the command line and entry-points. All applications should transition away
100
+ /// from this interface and still work instead via <see cref="IEstimator{TTransformer}"/> of <see cref="CalibratorTransformer{TICalibrator}"/>,
101
+ /// for example, the subclasses of <see cref="CalibratorEstimatorBase{TICalibrator}"/>. However for now we retain this
102
+ /// until such time as those components making use of it can transition to the new way. No public surface should use
103
+ /// this, and even new internal code should avoid its use if possible.
104
+ /// </summary>
105
+ [ BestFriend ]
106
+ internal interface ICalibratorTrainer
96
107
{
97
108
/// <summary>
98
109
/// True if the calibrator needs training, false otherwise.
@@ -107,6 +118,17 @@ public interface ICalibratorTrainer
107
118
ICalibrator FinishTraining ( IChannel ch ) ;
108
119
}
109
120
121
+ /// <summary>
122
+ /// This is a shim interface implemented only by <see cref="CalibratorEstimatorBase{TICalibrator}"/> to enable
123
+ /// access to the underlying legacy <see cref="ICalibratorTrainer"/> interface for those components that use
124
+ /// that old mechanism that we do not care to change right now.
125
+ /// </summary>
126
+ [ BestFriend ]
127
+ internal interface IHaveCalibratorTrainer
128
+ {
129
+ ICalibratorTrainer CalibratorTrainer { get ; }
130
+ }
131
+
110
132
/// <summary>
111
133
/// An interface for predictors that take care of their own calibration given an input data view.
112
134
/// </summary>
@@ -842,6 +864,64 @@ public static IPredictor GetCalibratedPredictor(IHostEnvironment env, IChannel c
842
864
return CreateCalibratedPredictor ( env , ( IPredictorProducing < float > ) predictor , trainedCalibrator ) ;
843
865
}
844
866
867
+ public static ICalibrator TrainCalibrator ( IHostEnvironment env , IChannel ch , ICalibratorTrainer caliTrainer , IDataView scored , string labelColumn , string scoreColumn , string weightColumn = null , int maxRows = _maxCalibrationExamples )
868
+ {
869
+ Contracts . CheckValue ( env , nameof ( env ) ) ;
870
+ env . CheckValue ( ch , nameof ( ch ) ) ;
871
+ ch . CheckValue ( scored , nameof ( scored ) ) ;
872
+ ch . CheckValue ( caliTrainer , nameof ( caliTrainer ) ) ;
873
+ ch . CheckParam ( ! caliTrainer . NeedsTraining || ! string . IsNullOrWhiteSpace ( labelColumn ) , nameof ( labelColumn ) ,
874
+ "If " + nameof ( caliTrainer ) + " requires training, then " + nameof ( labelColumn ) + " must have a value." ) ;
875
+ ch . CheckNonWhiteSpace ( scoreColumn , nameof ( scoreColumn ) ) ;
876
+
877
+ if ( ! caliTrainer . NeedsTraining )
878
+ return caliTrainer . FinishTraining ( ch ) ;
879
+
880
+ var labelCol = scored . Schema [ labelColumn ] ;
881
+ var scoreCol = scored . Schema [ scoreColumn ] ;
882
+
883
+ var weightCol = weightColumn == null ? null : scored . Schema . GetColumnOrNull ( weightColumn ) ;
884
+ if ( weightColumn != null && ! weightCol . HasValue )
885
+ throw ch . ExceptSchemaMismatch ( nameof ( weightColumn ) , "weight" , weightColumn ) ;
886
+
887
+ ch . Info ( "Training calibrator." ) ;
888
+
889
+ var cols = weightCol . HasValue ?
890
+ new Schema . Column [ ] { labelCol , scoreCol , weightCol . Value } :
891
+ new Schema . Column [ ] { labelCol , scoreCol } ;
892
+
893
+ using ( var cursor = scored . GetRowCursor ( cols ) )
894
+ {
895
+ var labelGetter = RowCursorUtils . GetLabelGetter ( cursor , labelCol . Index ) ;
896
+ var scoreGetter = RowCursorUtils . GetGetterAs < Single > ( NumberType . R4 , cursor , scoreCol . Index ) ;
897
+ ValueGetter < Single > weightGetter = ! weightCol . HasValue ? ( ref float dst ) => dst = 1 :
898
+ RowCursorUtils . GetGetterAs < Single > ( NumberType . R4 , cursor , weightCol . Value . Index ) ;
899
+
900
+ int num = 0 ;
901
+ while ( cursor . MoveNext ( ) )
902
+ {
903
+ Single label = 0 ;
904
+ labelGetter ( ref label ) ;
905
+ if ( ! FloatUtils . IsFinite ( label ) )
906
+ continue ;
907
+ Single score = 0 ;
908
+ scoreGetter ( ref score ) ;
909
+ if ( ! FloatUtils . IsFinite ( score ) )
910
+ continue ;
911
+ Single weight = 0 ;
912
+ weightGetter ( ref weight ) ;
913
+ if ( ! FloatUtils . IsFinite ( weight ) )
914
+ continue ;
915
+
916
+ caliTrainer . ProcessTrainingExample ( score , label > 0 , weight ) ;
917
+
918
+ if ( maxRows > 0 && ++ num >= maxRows )
919
+ break ;
920
+ }
921
+ }
922
+ return caliTrainer . FinishTraining ( ch ) ;
923
+ }
924
+
845
925
/// <summary>
846
926
/// Trains a calibrator.
847
927
/// </summary>
@@ -857,60 +937,14 @@ public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICa
857
937
{
858
938
Contracts . CheckValue ( env , nameof ( env ) ) ;
859
939
env . CheckValue ( ch , nameof ( ch ) ) ;
940
+ ch . CheckValue ( caliTrainer , nameof ( caliTrainer ) ) ;
860
941
ch . CheckValue ( predictor , nameof ( predictor ) ) ;
861
942
ch . CheckValue ( data , nameof ( data ) ) ;
862
943
ch . CheckParam ( data . Schema . Label . HasValue , nameof ( data ) , "data must have a Label column" ) ;
863
944
864
945
var scored = ScoreUtils . GetScorer ( predictor , data , env , null ) ;
865
-
866
- if ( caliTrainer . NeedsTraining )
867
- {
868
- int labelCol ;
869
- if ( ! scored . Schema . TryGetColumnIndex ( data . Schema . Label . Value . Name , out labelCol ) )
870
- throw ch . Except ( "No label column found" ) ;
871
- int scoreCol ;
872
- if ( ! scored . Schema . TryGetColumnIndex ( MetadataUtils . Const . ScoreValueKind . Score , out scoreCol ) )
873
- throw ch . Except ( "No score column found" ) ;
874
- int weightCol = - 1 ;
875
- if ( data . Schema . Weight ? . Name is string weightName && scored . Schema . GetColumnOrNull ( weightName ) ? . Index is int weightIdx )
876
- weightCol = weightIdx ;
877
- ch . Info ( "Training calibrator." ) ;
878
-
879
- var cols = weightCol > - 1 ?
880
- new Schema . Column [ ] { scored . Schema [ labelCol ] , scored . Schema [ scoreCol ] , scored . Schema [ weightCol ] } :
881
- new Schema . Column [ ] { scored . Schema [ labelCol ] , scored . Schema [ scoreCol ] } ;
882
-
883
- using ( var cursor = scored . GetRowCursor ( cols ) )
884
- {
885
- var labelGetter = RowCursorUtils . GetLabelGetter ( cursor , labelCol ) ;
886
- var scoreGetter = RowCursorUtils . GetGetterAs < Single > ( NumberType . R4 , cursor , scoreCol ) ;
887
- ValueGetter < Single > weightGetter = weightCol == - 1 ? ( ref float dst ) => dst = 1 :
888
- RowCursorUtils . GetGetterAs < Single > ( NumberType . R4 , cursor , weightCol ) ;
889
-
890
- int num = 0 ;
891
- while ( cursor . MoveNext ( ) )
892
- {
893
- Single label = 0 ;
894
- labelGetter ( ref label ) ;
895
- if ( ! FloatUtils . IsFinite ( label ) )
896
- continue ;
897
- Single score = 0 ;
898
- scoreGetter ( ref score ) ;
899
- if ( ! FloatUtils . IsFinite ( score ) )
900
- continue ;
901
- Single weight = 0 ;
902
- weightGetter ( ref weight ) ;
903
- if ( ! FloatUtils . IsFinite ( weight ) )
904
- continue ;
905
-
906
- caliTrainer . ProcessTrainingExample ( score , label > 0 , weight ) ;
907
-
908
- if ( maxRows > 0 && ++ num >= maxRows )
909
- break ;
910
- }
911
- }
912
- }
913
- return caliTrainer . FinishTraining ( ch ) ;
946
+ var scoreColumn = scored . Schema [ DefaultColumnNames . Score ] ;
947
+ return TrainCalibrator ( env , ch , caliTrainer , scored , data . Schema . Label . Value . Name , DefaultColumnNames . Score , data . Schema . Weight ? . Name , maxRows ) ;
914
948
}
915
949
916
950
public static IPredictorProducing < float > CreateCalibratedPredictor < TSubPredictor , TCalibrator > ( IHostEnvironment env , TSubPredictor predictor , TCalibrator cali )
@@ -953,7 +987,8 @@ public ICalibratorTrainer CreateComponent(IHostEnvironment env)
953
987
/// The probability of belonging to a particular class, for example class 1, is the number of class 1 instances in the bin, divided by the total number
954
988
/// of instances in that bin.
955
989
/// </summary>
956
- public sealed class NaiveCalibratorTrainer : ICalibratorTrainer
990
+ [ BestFriend ]
991
+ internal sealed class NaiveCalibratorTrainer : ICalibratorTrainer
957
992
{
958
993
private readonly IHost _host ;
959
994
@@ -1181,7 +1216,8 @@ public string GetSummary()
1181
1216
/// <summary>
1182
1217
/// Base class for calibrator trainers.
1183
1218
/// </summary>
1184
- public abstract class CalibratorTrainerBase : ICalibratorTrainer
1219
+ [ BestFriend ]
1220
+ internal abstract class CalibratorTrainerBase : ICalibratorTrainer
1185
1221
{
1186
1222
protected readonly IHost Host ;
1187
1223
protected CalibrationDataStore Data ;
@@ -1230,7 +1266,8 @@ public ICalibratorTrainer CreateComponent(IHostEnvironment env)
1230
1266
}
1231
1267
}
1232
1268
1233
- public sealed class PlattCalibratorTrainer : CalibratorTrainerBase
1269
+ [ BestFriend ]
1270
+ internal sealed class PlattCalibratorTrainer : CalibratorTrainerBase
1234
1271
{
1235
1272
internal const string UserName = "Sigmoid Calibration" ;
1236
1273
internal const string LoadName = "PlattCalibration" ;
@@ -1389,7 +1426,8 @@ public override ICalibrator CreateCalibrator(IChannel ch)
1389
1426
}
1390
1427
}
1391
1428
1392
- public sealed class FixedPlattCalibratorTrainer : ICalibratorTrainer
1429
+ [ BestFriend ]
1430
+ internal sealed class FixedPlattCalibratorTrainer : ICalibratorTrainer
1393
1431
{
1394
1432
[ TlcModule . Component ( Name = "FixedPlattCalibrator" , FriendlyName = "Fixed Platt Calibrator" , Aliases = new [ ] { "FixedPlatt" , "FixedSigmoid" } ) ]
1395
1433
public sealed class Arguments : ICalibratorTrainerFactory
@@ -1591,7 +1629,8 @@ public ICalibratorTrainer CreateComponent(IHostEnvironment env)
1591
1629
}
1592
1630
}
1593
1631
1594
- public class PavCalibratorTrainer : CalibratorTrainerBase
1632
+ [ BestFriend ]
1633
+ internal sealed class PavCalibratorTrainer : CalibratorTrainerBase
1595
1634
{
1596
1635
// a piece of the piecwise function
1597
1636
private readonly struct Piece
@@ -1664,6 +1703,7 @@ public override ICalibrator CreateCalibrator(IChannel ch)
1664
1703
}
1665
1704
1666
1705
/// <summary>
1706
+ /// The pair-adjacent violators calibrator.
1667
1707
/// The function that is implemented by this calibrator is:
1668
1708
/// f(x) = v_i, if minX_i <= x <= maxX_i
1669
1709
/// = linear interpolate between v_i and v_i+1, if maxX_i < x < minX_i+1
0 commit comments