Skip to content

Commit 25404b8

Browse files
authored
Cleanup calibrators (#2601)
1 parent a56caee commit 25404b8

File tree

3 files changed

+176
-158
lines changed

3 files changed

+176
-158
lines changed

src/Microsoft.ML.Data/Prediction/Calibrator.cs

+126-118
Original file line numberDiff line numberDiff line change
@@ -1110,20 +1110,22 @@ private static VersionInfo GetVersionInfo()
11101110
public readonly float Min;
11111111

11121112
/// <summary> The value of probability in each bin.</summary>
1113-
public readonly float[] BinProbs;
1113+
public IReadOnlyList<float> BinProbs => _binProbs;
1114+
1115+
private readonly float[] _binProbs;
11141116

11151117
/// <summary> Initializes a new instance of <see cref="NaiveCalibrator"/>.</summary>
11161118
/// <param name="env">The <see cref="IHostEnvironment"/> to use.</param>
11171119
/// <param name="min">The minimum value in the first bin.</param>
11181120
/// <param name="binProbs">The values of the probability in each bin.</param>
11191121
/// <param name="binSize">The bin size.</param>
1120-
public NaiveCalibrator(IHostEnvironment env, float min, float binSize, float[] binProbs)
1122+
internal NaiveCalibrator(IHostEnvironment env, float min, float binSize, float[] binProbs)
11211123
{
11221124
Contracts.CheckValue(env, nameof(env));
11231125
_host = env.Register(RegistrationName);
11241126
Min = min;
11251127
BinSize = binSize;
1126-
BinProbs = binProbs;
1128+
_binProbs = binProbs;
11271129
}
11281130

11291131
private NaiveCalibrator(IHostEnvironment env, ModelLoadContext ctx)
@@ -1147,9 +1149,9 @@ private NaiveCalibrator(IHostEnvironment env, ModelLoadContext ctx)
11471149
Min = ctx.Reader.ReadFloat();
11481150
_host.CheckDecode(FloatUtils.IsFinite(Min));
11491151

1150-
BinProbs = ctx.Reader.ReadFloatArray();
1151-
_host.CheckDecode(Utils.Size(BinProbs) > 0);
1152-
_host.CheckDecode(BinProbs.All(x => (0 <= x && x <= 1)));
1152+
_binProbs = ctx.Reader.ReadFloatArray();
1153+
_host.CheckDecode(Utils.Size(_binProbs) > 0);
1154+
_host.CheckDecode(_binProbs.All(x => (0 <= x && x <= 1)));
11531155
}
11541156

11551157
private static NaiveCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
@@ -1180,7 +1182,7 @@ private void SaveCore(ModelSaveContext ctx)
11801182
ctx.Writer.Write(sizeof(float));
11811183
ctx.Writer.Write(BinSize);
11821184
ctx.Writer.Write(Min);
1183-
ctx.Writer.WriteSingleArray(BinProbs);
1185+
ctx.Writer.WriteSingleArray(_binProbs);
11841186
}
11851187

11861188
/// <summary>
@@ -1190,8 +1192,8 @@ public float PredictProbability(float output)
11901192
{
11911193
if (float.IsNaN(output))
11921194
return output;
1193-
int binIdx = GetBinIdx(output, Min, BinSize, BinProbs.Length);
1194-
return BinProbs[binIdx];
1195+
int binIdx = GetBinIdx(output, Min, BinSize, _binProbs.Length);
1196+
return _binProbs[binIdx];
11951197
}
11961198

11971199
// get the bin for a given output
@@ -1205,11 +1207,6 @@ internal static int GetBinIdx(float output, float min, float binSize, int numBin
12051207
return binIdx;
12061208
}
12071209

1208-
/// <summary> Get the summary of current calibrator settings </summary>
1209-
public string GetSummary()
1210-
{
1211-
return string.Format("Naive Calibrator has {0} bins, starting at {1}, with bin size of {2}", BinProbs.Length, Min, BinSize);
1212-
}
12131210
}
12141211

12151212
/// <summary>
@@ -1218,8 +1215,91 @@ public string GetSummary()
12181215
[BestFriend]
12191216
internal abstract class CalibratorTrainerBase : ICalibratorTrainer
12201217
{
1218+
public sealed class DataStore : IEnumerable<DataStore.DataItem>
1219+
{
1220+
public readonly struct DataItem
1221+
{
1222+
// The actual binary label of this example.
1223+
public readonly bool Target;
1224+
// The weight associated with this example.
1225+
public readonly float Weight;
1226+
// The output of the example.
1227+
public readonly float Score;
1228+
1229+
public DataItem(bool target, float weight, float score)
1230+
{
1231+
Target = target;
1232+
Weight = weight;
1233+
Score = score;
1234+
}
1235+
}
1236+
1237+
// REVIEW: Should probably be a long.
1238+
private int _itemsSeen;
1239+
private readonly Random _random;
1240+
1241+
private static int _randSeed;
1242+
1243+
private readonly int _capacity;
1244+
private DataItem[] _data;
1245+
private bool _dataSorted;
1246+
1247+
public DataStore()
1248+
: this(1000000)
1249+
{
1250+
}
1251+
1252+
public DataStore(int capacity)
1253+
{
1254+
Contracts.CheckParam(capacity > 0, nameof(capacity), "must be positive");
1255+
1256+
_capacity = capacity;
1257+
_data = new DataItem[Math.Min(4, capacity)];
1258+
// REVIEW: Horrifying. At a point when we have the IHost stuff plumbed through
1259+
// calibrator training and also have the appetite to change a bunch of baselines, this
1260+
// should be seeded using the host random.
1261+
_random = new System.Random(System.Threading.Interlocked.Increment(ref _randSeed) - 1);
1262+
}
1263+
1264+
/// <summary>
1265+
/// An enumerator over the <see cref="DataItem"/> entries sorted by score.
1266+
/// </summary>
1267+
/// <returns></returns>
1268+
public IEnumerator<DataItem> GetEnumerator()
1269+
{
1270+
if (!_dataSorted)
1271+
{
1272+
var comp = Comparer<DataItem>.Create((x, y) => x.Score.CompareTo(y.Score));
1273+
Array.Sort(_data, 0, Math.Min(_itemsSeen, _capacity), comp);
1274+
_dataSorted = true;
1275+
}
1276+
return _data.Take(_itemsSeen).GetEnumerator();
1277+
}
1278+
1279+
IEnumerator IEnumerable.GetEnumerator()
1280+
{
1281+
return GetEnumerator();
1282+
}
1283+
1284+
public void AddToStore(float score, bool isPositive, float weight)
1285+
{
1286+
// Can't calibrate NaN scores.
1287+
if (weight == 0 || float.IsNaN(score))
1288+
return;
1289+
int index = _itemsSeen++;
1290+
if (_itemsSeen <= _capacity)
1291+
Utils.EnsureSize(ref _data, _itemsSeen, _capacity);
1292+
else
1293+
{
1294+
index = _random.Next(_itemsSeen); // 0 to items_seen - 1.
1295+
if (index >= _capacity) // Don't keep it.
1296+
return;
1297+
}
1298+
_data[index] = new DataItem(isPositive, weight, score);
1299+
}
1300+
}
12211301
protected readonly IHost Host;
1222-
protected CalibrationDataStore Data;
1302+
protected DataStore Data;
12231303
protected const int DefaultMaxNumSamples = 1000000;
12241304
protected int MaxNumSamples;
12251305

@@ -1239,7 +1319,7 @@ protected CalibratorTrainerBase(IHostEnvironment env, string name)
12391319
bool ICalibratorTrainer.ProcessTrainingExample(float output, bool labelIs1, float weight)
12401320
{
12411321
if (Data == null)
1242-
Data = new CalibrationDataStore(MaxNumSamples);
1322+
Data = new DataStore(MaxNumSamples);
12431323
Data.AddToStore(output, labelIs1, weight);
12441324
return true;
12451325
}
@@ -1485,15 +1565,21 @@ private static VersionInfo GetVersionInfo()
14851565

14861566
private readonly IHost _host;
14871567

1568+
/// <summary>
1569+
/// Slope value for this calibrator.
1570+
/// </summary>
14881571
public Double Slope { get; }
1572+
/// <summary>
1573+
/// Offset value for this calibrator.
1574+
/// </summary>
14891575
public Double Offset { get; }
14901576
bool ICanSavePfa.CanSavePfa => true;
14911577
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;
14921578

14931579
/// <summary>
14941580
/// Initializes a new instance of <see cref="PlattCalibrator"/>.
14951581
/// </summary>
1496-
public PlattCalibrator(IHostEnvironment env, Double slope, Double offset)
1582+
internal PlattCalibrator(IHostEnvironment env, Double slope, Double offset)
14971583
{
14981584
Contracts.CheckValue(env, nameof(env));
14991585
_host = env.Register(RegistrationName);
@@ -1556,14 +1642,15 @@ private void SaveCore(ModelSaveContext ctx)
15561642
}
15571643
}
15581644

1645+
/// <summary> Given a classifier output, produce the probability.</summary>
15591646
public float PredictProbability(float output)
15601647
{
15611648
if (float.IsNaN(output))
15621649
return output;
15631650
return PredictProbability(output, Slope, Offset);
15641651
}
15651652

1566-
public static float PredictProbability(float output, Double a, Double b)
1653+
internal static float PredictProbability(float output, Double a, Double b)
15671654
{
15681655
return (float)(1 / (1 + Math.Exp(a * output + b)));
15691656
}
@@ -1597,11 +1684,6 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColu
15971684
return true;
15981685
}
15991686

1600-
public string GetSummary()
1601-
{
1602-
return string.Format("Platt calibrator parameters: A={0}, B={1}", Slope, Offset);
1603-
}
1604-
16051687
IParameterMixer IParameterMixer.CombineParameters(IList<IParameterMixer> calibrators)
16061688
{
16071689
Double a = 0;
@@ -1703,12 +1785,17 @@ public override ICalibrator CreateCalibrator(IChannel ch)
17031785

17041786
/// <summary>
17051787
/// The pair-adjacent violators calibrator.
1706-
/// The function that is implemented by this calibrator is:
1707-
/// f(x) = v_i, if minX_i &lt;= x &lt;= maxX_i
1708-
/// = linear interpolate between v_i and v_i+1, if maxX_i &lt; x &lt; minX_i+1
1709-
/// = v_0, if x &lt; minX_0
1710-
/// = v_n, if x &gt; maxX_n
17111788
/// </summary>
1789+
/// <remarks>
1790+
/// The function that is implemented by this calibrator is:
1791+
/// P(x) =
1792+
/// <list type="bullet">
1793+
/// <item><description><see cref="Values"/>[i], if <see cref="Mins"/>[i] &lt;= x &lt;= <see cref="Maxes"/>[i]</description></item>
1794+
/// <item> <description>Linear interpolation between <see cref="Values"/>[i] and <see cref="Values"/>[i+1], if <see cref="Maxes"/>[i] &lt; x &lt; <see cref="Mins"/>[i+1]</description></item>
1795+
/// <item><description><see cref="Values"/>[0], if x &lt; <see cref="Mins"/>[0]</description></item>
1796+
/// <item><description><see cref="Values"/>[n], if x &gt; <see cref="Maxes"/>[n]</description></item>
1797+
/// </list>
1798+
/// </remarks>
17121799
public sealed class PavCalibrator : ICalibrator, ICanSaveInBinaryFormat
17131800
{
17141801
internal const string LoaderSignature = "PAVCaliExec";
@@ -1731,8 +1818,17 @@ private static VersionInfo GetVersionInfo()
17311818
private const float MaxToReturn = 1 - Epsilon; // max predicted is 1 - min;
17321819

17331820
private readonly IHost _host;
1821+
/// <summary>
1822+
/// Bottom borders of PAV intervals.
1823+
/// </summary>
17341824
public readonly ImmutableArray<float> Mins;
1825+
/// <summary>
1826+
/// Upper borders of PAV intervals.
1827+
/// </summary>
17351828
public readonly ImmutableArray<float> Maxes;
1829+
/// <summary>
1830+
/// Values of PAV intervals.
1831+
/// </summary>
17361832
public readonly ImmutableArray<float> Values;
17371833

17381834
/// <summary>
@@ -1742,7 +1838,7 @@ private static VersionInfo GetVersionInfo()
17421838
/// <param name="mins">The minimum values for each piece.</param>
17431839
/// <param name="maxes">The maximum values for each piece.</param>
17441840
/// <param name="values">The actual values for each piece.</param>
1745-
public PavCalibrator(IHostEnvironment env, ImmutableArray<float> mins, ImmutableArray<float> maxes, ImmutableArray<float> values)
1841+
internal PavCalibrator(IHostEnvironment env, ImmutableArray<float> mins, ImmutableArray<float> maxes, ImmutableArray<float> values)
17461842
{
17471843
Contracts.AssertValue(env);
17481844
_host = env.Register(RegistrationName);
@@ -1851,6 +1947,7 @@ private void SaveCore(ModelSaveContext ctx)
18511947
_host.CheckDecode(valuePrev <= 1);
18521948
}
18531949

1950+
/// <summary> Given a classifier output, produce the probability.</summary>
18541951
public float PredictProbability(float output)
18551952
{
18561953
if (float.IsNaN(output))
@@ -1890,95 +1987,6 @@ private float FindValue(float score)
18901987
float t = (score - Maxes[pos - 1]) / (Mins[pos] - Maxes[pos - 1]);
18911988
return Values[pos - 1] + t * (Values[pos] - Values[pos - 1]);
18921989
}
1893-
1894-
public string GetSummary()
1895-
{
1896-
return string.Format("PAV calibrator with {0} intervals", Mins.Length);
1897-
}
1898-
}
1899-
1900-
public sealed class CalibrationDataStore : IEnumerable<CalibrationDataStore.DataItem>
1901-
{
1902-
public readonly struct DataItem
1903-
{
1904-
// The actual binary label of this example.
1905-
public readonly bool Target;
1906-
// The weight associated with this example.
1907-
public readonly float Weight;
1908-
// The output of the example.
1909-
public readonly float Score;
1910-
1911-
public DataItem(bool target, float weight, float score)
1912-
{
1913-
Target = target;
1914-
Weight = weight;
1915-
Score = score;
1916-
}
1917-
}
1918-
1919-
// REVIEW: Should probably be a long.
1920-
private int _itemsSeen;
1921-
private readonly Random _random;
1922-
1923-
private static int _randSeed;
1924-
1925-
private readonly int _capacity;
1926-
private DataItem[] _data;
1927-
private bool _dataSorted;
1928-
1929-
public CalibrationDataStore()
1930-
: this(1000000)
1931-
{
1932-
}
1933-
1934-
public CalibrationDataStore(int capacity)
1935-
{
1936-
Contracts.CheckParam(capacity > 0, nameof(capacity), "must be positive");
1937-
1938-
_capacity = capacity;
1939-
_data = new DataItem[Math.Min(4, capacity)];
1940-
// REVIEW: Horrifying. At a point when we have the IHost stuff plumbed through
1941-
// calibrator training and also have the appetite to change a bunch of baselines, this
1942-
// should be seeded using the host random.
1943-
_random = new System.Random(System.Threading.Interlocked.Increment(ref _randSeed) - 1);
1944-
}
1945-
1946-
/// <summary>
1947-
/// An enumerator over the <see cref="DataItem"/> entries sorted by score.
1948-
/// </summary>
1949-
/// <returns></returns>
1950-
public IEnumerator<DataItem> GetEnumerator()
1951-
{
1952-
if (!_dataSorted)
1953-
{
1954-
var comp = Comparer<DataItem>.Create((x, y) => x.Score.CompareTo(y.Score));
1955-
Array.Sort(_data, 0, Math.Min(_itemsSeen, _capacity), comp);
1956-
_dataSorted = true;
1957-
}
1958-
return _data.Take(_itemsSeen).GetEnumerator();
1959-
}
1960-
1961-
IEnumerator IEnumerable.GetEnumerator()
1962-
{
1963-
return GetEnumerator();
1964-
}
1965-
1966-
public void AddToStore(float score, bool isPositive, float weight)
1967-
{
1968-
// Can't calibrate NaN scores.
1969-
if (weight == 0 || float.IsNaN(score))
1970-
return;
1971-
int index = _itemsSeen++;
1972-
if (_itemsSeen <= _capacity)
1973-
Utils.EnsureSize(ref _data, _itemsSeen, _capacity);
1974-
else
1975-
{
1976-
index = _random.Next(_itemsSeen); // 0 to items_seen - 1.
1977-
if (index >= _capacity) // Don't keep it.
1978-
return;
1979-
}
1980-
_data[index] = new DataItem(isPositive, weight, score);
1981-
}
19821990
}
19831991

19841992
internal static class Calibrate

0 commit comments

Comments
 (0)