@@ -1110,20 +1110,22 @@ private static VersionInfo GetVersionInfo()
1110
1110
public readonly float Min ;
1111
1111
1112
1112
/// <summary> The value of probability in each bin.</summary>
1113
- public readonly float [ ] BinProbs ;
1113
+ public IReadOnlyList < float > BinProbs => _binProbs ;
1114
+
1115
+ private readonly float [ ] _binProbs ;
1114
1116
1115
1117
/// <summary> Initializes a new instance of <see cref="NaiveCalibrator"/>.</summary>
1116
1118
/// <param name="env">The <see cref="IHostEnvironment"/> to use.</param>
1117
1119
/// <param name="min">The minimum value in the first bin.</param>
1118
1120
/// <param name="binProbs">The values of the probability in each bin.</param>
1119
1121
/// <param name="binSize">The bin size.</param>
1120
- public NaiveCalibrator ( IHostEnvironment env , float min , float binSize , float [ ] binProbs )
1122
+ internal NaiveCalibrator ( IHostEnvironment env , float min , float binSize , float [ ] binProbs )
1121
1123
{
1122
1124
Contracts . CheckValue ( env , nameof ( env ) ) ;
1123
1125
_host = env . Register ( RegistrationName ) ;
1124
1126
Min = min ;
1125
1127
BinSize = binSize ;
1126
- BinProbs = binProbs ;
1128
+ _binProbs = binProbs ;
1127
1129
}
1128
1130
1129
1131
private NaiveCalibrator ( IHostEnvironment env , ModelLoadContext ctx )
@@ -1147,9 +1149,9 @@ private NaiveCalibrator(IHostEnvironment env, ModelLoadContext ctx)
1147
1149
Min = ctx . Reader . ReadFloat ( ) ;
1148
1150
_host . CheckDecode ( FloatUtils . IsFinite ( Min ) ) ;
1149
1151
1150
- BinProbs = ctx . Reader . ReadFloatArray ( ) ;
1151
- _host . CheckDecode ( Utils . Size ( BinProbs ) > 0 ) ;
1152
- _host . CheckDecode ( BinProbs . All ( x => ( 0 <= x && x <= 1 ) ) ) ;
1152
+ _binProbs = ctx . Reader . ReadFloatArray ( ) ;
1153
+ _host . CheckDecode ( Utils . Size ( _binProbs ) > 0 ) ;
1154
+ _host . CheckDecode ( _binProbs . All ( x => ( 0 <= x && x <= 1 ) ) ) ;
1153
1155
}
1154
1156
1155
1157
private static NaiveCalibrator Create ( IHostEnvironment env , ModelLoadContext ctx )
@@ -1180,7 +1182,7 @@ private void SaveCore(ModelSaveContext ctx)
1180
1182
ctx . Writer . Write ( sizeof ( float ) ) ;
1181
1183
ctx . Writer . Write ( BinSize ) ;
1182
1184
ctx . Writer . Write ( Min ) ;
1183
- ctx . Writer . WriteSingleArray ( BinProbs ) ;
1185
+ ctx . Writer . WriteSingleArray ( _binProbs ) ;
1184
1186
}
1185
1187
1186
1188
/// <summary>
@@ -1190,8 +1192,8 @@ public float PredictProbability(float output)
1190
1192
{
1191
1193
if ( float . IsNaN ( output ) )
1192
1194
return output ;
1193
- int binIdx = GetBinIdx ( output , Min , BinSize , BinProbs . Length ) ;
1194
- return BinProbs [ binIdx ] ;
1195
+ int binIdx = GetBinIdx ( output , Min , BinSize , _binProbs . Length ) ;
1196
+ return _binProbs [ binIdx ] ;
1195
1197
}
1196
1198
1197
1199
// get the bin for a given output
@@ -1205,11 +1207,6 @@ internal static int GetBinIdx(float output, float min, float binSize, int numBin
1205
1207
return binIdx ;
1206
1208
}
1207
1209
1208
- /// <summary> Get the summary of current calibrator settings </summary>
1209
- public string GetSummary ( )
1210
- {
1211
- return string . Format ( "Naive Calibrator has {0} bins, starting at {1}, with bin size of {2}" , BinProbs . Length , Min , BinSize ) ;
1212
- }
1213
1210
}
1214
1211
1215
1212
/// <summary>
@@ -1218,8 +1215,91 @@ public string GetSummary()
1218
1215
[ BestFriend ]
1219
1216
internal abstract class CalibratorTrainerBase : ICalibratorTrainer
1220
1217
{
1218
+ public sealed class DataStore : IEnumerable < DataStore . DataItem >
1219
+ {
1220
+ public readonly struct DataItem
1221
+ {
1222
+ // The actual binary label of this example.
1223
+ public readonly bool Target ;
1224
+ // The weight associated with this example.
1225
+ public readonly float Weight ;
1226
+ // The output of the example.
1227
+ public readonly float Score ;
1228
+
1229
+ public DataItem ( bool target , float weight , float score )
1230
+ {
1231
+ Target = target ;
1232
+ Weight = weight ;
1233
+ Score = score ;
1234
+ }
1235
+ }
1236
+
1237
+ // REVIEW: Should probably be a long.
1238
+ private int _itemsSeen ;
1239
+ private readonly Random _random ;
1240
+
1241
+ private static int _randSeed ;
1242
+
1243
+ private readonly int _capacity ;
1244
+ private DataItem [ ] _data ;
1245
+ private bool _dataSorted ;
1246
+
1247
+ public DataStore ( )
1248
+ : this ( 1000000 )
1249
+ {
1250
+ }
1251
+
1252
+ public DataStore ( int capacity )
1253
+ {
1254
+ Contracts . CheckParam ( capacity > 0 , nameof ( capacity ) , "must be positive" ) ;
1255
+
1256
+ _capacity = capacity ;
1257
+ _data = new DataItem [ Math . Min ( 4 , capacity ) ] ;
1258
+ // REVIEW: Horrifying. At a point when we have the IHost stuff plumbed through
1259
+ // calibrator training and also have the appetite to change a bunch of baselines, this
1260
+ // should be seeded using the host random.
1261
+ _random = new System . Random ( System . Threading . Interlocked . Increment ( ref _randSeed ) - 1 ) ;
1262
+ }
1263
+
1264
+ /// <summary>
1265
+ /// An enumerator over the <see cref="DataItem"/> entries sorted by score.
1266
+ /// </summary>
1267
+ /// <returns></returns>
1268
+ public IEnumerator < DataItem > GetEnumerator ( )
1269
+ {
1270
+ if ( ! _dataSorted )
1271
+ {
1272
+ var comp = Comparer < DataItem > . Create ( ( x , y ) => x . Score . CompareTo ( y . Score ) ) ;
1273
+ Array . Sort ( _data , 0 , Math . Min ( _itemsSeen , _capacity ) , comp ) ;
1274
+ _dataSorted = true ;
1275
+ }
1276
+ return _data . Take ( _itemsSeen ) . GetEnumerator ( ) ;
1277
+ }
1278
+
1279
+ IEnumerator IEnumerable . GetEnumerator ( )
1280
+ {
1281
+ return GetEnumerator ( ) ;
1282
+ }
1283
+
1284
+ public void AddToStore ( float score , bool isPositive , float weight )
1285
+ {
1286
+ // Can't calibrate NaN scores.
1287
+ if ( weight == 0 || float . IsNaN ( score ) )
1288
+ return ;
1289
+ int index = _itemsSeen ++ ;
1290
+ if ( _itemsSeen <= _capacity )
1291
+ Utils . EnsureSize ( ref _data , _itemsSeen , _capacity ) ;
1292
+ else
1293
+ {
1294
+ index = _random . Next ( _itemsSeen ) ; // 0 to items_seen - 1.
1295
+ if ( index >= _capacity ) // Don't keep it.
1296
+ return ;
1297
+ }
1298
+ _data [ index ] = new DataItem ( isPositive , weight , score ) ;
1299
+ }
1300
+ }
1221
1301
protected readonly IHost Host ;
1222
- protected CalibrationDataStore Data ;
1302
+ protected DataStore Data ;
1223
1303
protected const int DefaultMaxNumSamples = 1000000 ;
1224
1304
protected int MaxNumSamples ;
1225
1305
@@ -1239,7 +1319,7 @@ protected CalibratorTrainerBase(IHostEnvironment env, string name)
1239
1319
bool ICalibratorTrainer . ProcessTrainingExample ( float output , bool labelIs1 , float weight )
1240
1320
{
1241
1321
if ( Data == null )
1242
- Data = new CalibrationDataStore ( MaxNumSamples ) ;
1322
+ Data = new DataStore ( MaxNumSamples ) ;
1243
1323
Data . AddToStore ( output , labelIs1 , weight ) ;
1244
1324
return true ;
1245
1325
}
@@ -1485,15 +1565,21 @@ private static VersionInfo GetVersionInfo()
1485
1565
1486
1566
private readonly IHost _host ;
1487
1567
1568
+ /// <summary>
1569
+ /// Slope value for this calibrator.
1570
+ /// </summary>
1488
1571
public Double Slope { get ; }
1572
+ /// <summary>
1573
+ /// Offset value for this calibrator
1574
+ /// </summary>
1489
1575
public Double Offset { get ; }
1490
1576
bool ICanSavePfa . CanSavePfa => true ;
1491
1577
bool ICanSaveOnnx . CanSaveOnnx ( OnnxContext ctx ) => true ;
1492
1578
1493
1579
/// <summary>
1494
1580
/// Initializes a new instance of <see cref="PlattCalibrator"/>.
1495
1581
/// </summary>
1496
- public PlattCalibrator ( IHostEnvironment env , Double slope , Double offset )
1582
+ internal PlattCalibrator ( IHostEnvironment env , Double slope , Double offset )
1497
1583
{
1498
1584
Contracts . CheckValue ( env , nameof ( env ) ) ;
1499
1585
_host = env . Register ( RegistrationName ) ;
@@ -1556,14 +1642,15 @@ private void SaveCore(ModelSaveContext ctx)
1556
1642
}
1557
1643
}
1558
1644
1645
+ /// <summary> Given a classifier output, produce the probability.</summary>
1559
1646
public float PredictProbability ( float output )
1560
1647
{
1561
1648
if ( float . IsNaN ( output ) )
1562
1649
return output ;
1563
1650
return PredictProbability ( output , Slope , Offset ) ;
1564
1651
}
1565
1652
1566
- public static float PredictProbability ( float output , Double a , Double b )
1653
+ internal static float PredictProbability ( float output , Double a , Double b )
1567
1654
{
1568
1655
return ( float ) ( 1 / ( 1 + Math . Exp ( a * output + b ) ) ) ;
1569
1656
}
@@ -1597,11 +1684,6 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColu
1597
1684
return true ;
1598
1685
}
1599
1686
1600
- public string GetSummary ( )
1601
- {
1602
- return string . Format ( "Platt calibrator parameters: A={0}, B={1}" , Slope , Offset ) ;
1603
- }
1604
-
1605
1687
IParameterMixer IParameterMixer . CombineParameters ( IList < IParameterMixer > calibrators )
1606
1688
{
1607
1689
Double a = 0 ;
@@ -1703,12 +1785,17 @@ public override ICalibrator CreateCalibrator(IChannel ch)
1703
1785
1704
1786
/// <summary>
1705
1787
/// The pair-adjacent violators calibrator.
1706
- /// The function that is implemented by this calibrator is:
1707
- /// f(x) = v_i, if minX_i <= x <= maxX_i
1708
- /// = linear interpolate between v_i and v_i+1, if maxX_i < x < minX_i+1
1709
- /// = v_0, if x < minX_0
1710
- /// = v_n, if x > maxX_n
1711
1788
/// </summary>
1789
+ /// <remarks>
1790
+ /// The function that is implemented by this calibrator is:
1791
+ /// P(x) =
1792
+ /// <list type="bullet">
1793
+ /// <item><description><see cref="Values"/>[i], if <see cref="Mins"/>[i] <= x <= <see cref="Maxes"/>[i]</description>></item>
1794
+ /// <item> <description>Linear interpolation between <see cref="Values"/>[i] and <see cref="Values"/>[i+1], if <see cref="Maxes"/>[i] < x < <see cref="Mins"/>[i+1]</description></item>
1795
+ /// <item><description><see cref="Values"/>[0], if x < <see cref="Mins"/>[0]</description></item>
1796
+ /// <item><description><see cref="Values"/>[n], if x > <see cref="Maxes"/>[n]</description></item>
1797
+ ///</list>
1798
+ /// </remarks>
1712
1799
public sealed class PavCalibrator : ICalibrator , ICanSaveInBinaryFormat
1713
1800
{
1714
1801
internal const string LoaderSignature = "PAVCaliExec" ;
@@ -1731,8 +1818,17 @@ private static VersionInfo GetVersionInfo()
1731
1818
private const float MaxToReturn = 1 - Epsilon ; // max predicted is 1 - min;
1732
1819
1733
1820
private readonly IHost _host ;
1821
+ /// <summary>
1822
+ /// Bottom borders of PAV intervals.
1823
+ /// </summary>
1734
1824
public readonly ImmutableArray < float > Mins ;
1825
+ /// <summary>
1826
+ /// Upper borders of PAV intervals.
1827
+ /// </summary>
1735
1828
public readonly ImmutableArray < float > Maxes ;
1829
+ /// <summary>
1830
+ /// Values of PAV intervals.
1831
+ /// </summary>
1736
1832
public readonly ImmutableArray < float > Values ;
1737
1833
1738
1834
/// <summary>
@@ -1742,7 +1838,7 @@ private static VersionInfo GetVersionInfo()
1742
1838
/// <param name="mins">The minimum values for each piece.</param>
1743
1839
/// <param name="maxes">The maximum values for each piece.</param>
1744
1840
/// <param name="values">The actual values for each piece.</param>
1745
- public PavCalibrator ( IHostEnvironment env , ImmutableArray < float > mins , ImmutableArray < float > maxes , ImmutableArray < float > values )
1841
+ internal PavCalibrator ( IHostEnvironment env , ImmutableArray < float > mins , ImmutableArray < float > maxes , ImmutableArray < float > values )
1746
1842
{
1747
1843
Contracts . AssertValue ( env ) ;
1748
1844
_host = env . Register ( RegistrationName ) ;
@@ -1851,6 +1947,7 @@ private void SaveCore(ModelSaveContext ctx)
1851
1947
_host . CheckDecode ( valuePrev <= 1 ) ;
1852
1948
}
1853
1949
1950
+ /// <summary> Given a classifier output, produce the probability.</summary>
1854
1951
public float PredictProbability ( float output )
1855
1952
{
1856
1953
if ( float . IsNaN ( output ) )
@@ -1890,95 +1987,6 @@ private float FindValue(float score)
1890
1987
float t = ( score - Maxes [ pos - 1 ] ) / ( Mins [ pos ] - Maxes [ pos - 1 ] ) ;
1891
1988
return Values [ pos - 1 ] + t * ( Values [ pos ] - Values [ pos - 1 ] ) ;
1892
1989
}
1893
-
1894
- public string GetSummary ( )
1895
- {
1896
- return string . Format ( "PAV calibrator with {0} intervals" , Mins . Length ) ;
1897
- }
1898
- }
1899
-
1900
- public sealed class CalibrationDataStore : IEnumerable < CalibrationDataStore . DataItem >
1901
- {
1902
- public readonly struct DataItem
1903
- {
1904
- // The actual binary label of this example.
1905
- public readonly bool Target ;
1906
- // The weight associated with this example.
1907
- public readonly float Weight ;
1908
- // The output of the example.
1909
- public readonly float Score ;
1910
-
1911
- public DataItem ( bool target , float weight , float score )
1912
- {
1913
- Target = target ;
1914
- Weight = weight ;
1915
- Score = score ;
1916
- }
1917
- }
1918
-
1919
- // REVIEW: Should probably be a long.
1920
- private int _itemsSeen ;
1921
- private readonly Random _random ;
1922
-
1923
- private static int _randSeed ;
1924
-
1925
- private readonly int _capacity ;
1926
- private DataItem [ ] _data ;
1927
- private bool _dataSorted ;
1928
-
1929
- public CalibrationDataStore ( )
1930
- : this ( 1000000 )
1931
- {
1932
- }
1933
-
1934
- public CalibrationDataStore ( int capacity )
1935
- {
1936
- Contracts . CheckParam ( capacity > 0 , nameof ( capacity ) , "must be positive" ) ;
1937
-
1938
- _capacity = capacity ;
1939
- _data = new DataItem [ Math . Min ( 4 , capacity ) ] ;
1940
- // REVIEW: Horrifying. At a point when we have the IHost stuff plumbed through
1941
- // calibrator training and also have the appetite to change a bunch of baselines, this
1942
- // should be seeded using the host random.
1943
- _random = new System . Random ( System . Threading . Interlocked . Increment ( ref _randSeed ) - 1 ) ;
1944
- }
1945
-
1946
- /// <summary>
1947
- /// An enumerator over the <see cref="DataItem"/> entries sorted by score.
1948
- /// </summary>
1949
- /// <returns></returns>
1950
- public IEnumerator < DataItem > GetEnumerator ( )
1951
- {
1952
- if ( ! _dataSorted )
1953
- {
1954
- var comp = Comparer < DataItem > . Create ( ( x , y ) => x . Score . CompareTo ( y . Score ) ) ;
1955
- Array . Sort ( _data , 0 , Math . Min ( _itemsSeen , _capacity ) , comp ) ;
1956
- _dataSorted = true ;
1957
- }
1958
- return _data . Take ( _itemsSeen ) . GetEnumerator ( ) ;
1959
- }
1960
-
1961
- IEnumerator IEnumerable . GetEnumerator ( )
1962
- {
1963
- return GetEnumerator ( ) ;
1964
- }
1965
-
1966
- public void AddToStore ( float score , bool isPositive , float weight )
1967
- {
1968
- // Can't calibrate NaN scores.
1969
- if ( weight == 0 || float . IsNaN ( score ) )
1970
- return ;
1971
- int index = _itemsSeen ++ ;
1972
- if ( _itemsSeen <= _capacity )
1973
- Utils . EnsureSize ( ref _data , _itemsSeen , _capacity ) ;
1974
- else
1975
- {
1976
- index = _random . Next ( _itemsSeen ) ; // 0 to items_seen - 1.
1977
- if ( index >= _capacity ) // Don't keep it.
1978
- return ;
1979
- }
1980
- _data [ index ] = new DataItem ( isPositive , weight , score ) ;
1981
- }
1982
1990
}
1983
1991
1984
1992
internal static class Calibrate
0 commit comments