-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Cleanup calibrators #2601
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Cleanup calibrators #2601
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1110,20 +1110,22 @@ private static VersionInfo GetVersionInfo() | |
public readonly float Min; | ||
|
||
/// <summary> The value of probability in each bin.</summary> | ||
public readonly float[] BinProbs; | ||
public IReadOnlyList<float> BinProbs => _binProbs; | ||
|
||
private readonly float[] _binProbs; | ||
|
||
/// <summary> Initializes a new instance of <see cref="NaiveCalibrator"/>.</summary> | ||
/// <param name="env">The <see cref="IHostEnvironment"/> to use.</param> | ||
/// <param name="min">The minimum value in the first bin.</param> | ||
/// <param name="binProbs">The values of the probability in each bin.</param> | ||
/// <param name="binSize">The bin size.</param> | ||
public NaiveCalibrator(IHostEnvironment env, float min, float binSize, float[] binProbs) | ||
internal NaiveCalibrator(IHostEnvironment env, float min, float binSize, float[] binProbs) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
_host = env.Register(RegistrationName); | ||
Min = min; | ||
BinSize = binSize; | ||
BinProbs = binProbs; | ||
_binProbs = binProbs; | ||
} | ||
|
||
private NaiveCalibrator(IHostEnvironment env, ModelLoadContext ctx) | ||
|
@@ -1147,9 +1149,9 @@ private NaiveCalibrator(IHostEnvironment env, ModelLoadContext ctx) | |
Min = ctx.Reader.ReadFloat(); | ||
_host.CheckDecode(FloatUtils.IsFinite(Min)); | ||
|
||
BinProbs = ctx.Reader.ReadFloatArray(); | ||
_host.CheckDecode(Utils.Size(BinProbs) > 0); | ||
_host.CheckDecode(BinProbs.All(x => (0 <= x && x <= 1))); | ||
_binProbs = ctx.Reader.ReadFloatArray(); | ||
_host.CheckDecode(Utils.Size(_binProbs) > 0); | ||
_host.CheckDecode(_binProbs.All(x => (0 <= x && x <= 1))); | ||
} | ||
|
||
private static NaiveCalibrator Create(IHostEnvironment env, ModelLoadContext ctx) | ||
|
@@ -1180,7 +1182,7 @@ private void SaveCore(ModelSaveContext ctx) | |
ctx.Writer.Write(sizeof(float)); | ||
ctx.Writer.Write(BinSize); | ||
ctx.Writer.Write(Min); | ||
ctx.Writer.WriteSingleArray(BinProbs); | ||
ctx.Writer.WriteSingleArray(_binProbs); | ||
} | ||
|
||
/// <summary> | ||
|
@@ -1190,8 +1192,8 @@ public float PredictProbability(float output) | |
{ | ||
if (float.IsNaN(output)) | ||
return output; | ||
int binIdx = GetBinIdx(output, Min, BinSize, BinProbs.Length); | ||
return BinProbs[binIdx]; | ||
int binIdx = GetBinIdx(output, Min, BinSize, _binProbs.Length); | ||
return _binProbs[binIdx]; | ||
} | ||
|
||
// get the bin for a given output | ||
|
@@ -1205,11 +1207,6 @@ internal static int GetBinIdx(float output, float min, float binSize, int numBin | |
return binIdx; | ||
} | ||
|
||
/// <summary> Get the summary of current calibrator settings </summary> | ||
public string GetSummary() | ||
{ | ||
return string.Format("Naive Calibrator has {0} bins, starting at {1}, with bin size of {2}", BinProbs.Length, Min, BinSize); | ||
} | ||
} | ||
|
||
/// <summary> | ||
|
@@ -1218,8 +1215,91 @@ public string GetSummary() | |
[BestFriend] | ||
internal abstract class CalibratorTrainerBase : ICalibratorTrainer | ||
{ | ||
public sealed class DataStore : IEnumerable<DataStore.DataItem> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Really vague name. Can we think of something more precise? |
||
{ | ||
public readonly struct DataItem | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Really vague name for a class visible throughout the assembly. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As far as I see it is a nested class inside the internal There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A) it's part of internal class. In reply to: 258183106 [](ancestors = 258183106) |
||
{ | ||
// The actual binary label of this example. | ||
public readonly bool Target; | ||
// The weight associated with this example. | ||
public readonly float Weight; | ||
// The output of the example. | ||
public readonly float Score; | ||
|
||
public DataItem(bool target, float weight, float score) | ||
{ | ||
Target = target; | ||
Weight = weight; | ||
Score = score; | ||
} | ||
} | ||
|
||
// REVIEW: Should probably be a long. | ||
private int _itemsSeen; | ||
private readonly Random _random; | ||
|
||
private static int _randSeed; | ||
|
||
private readonly int _capacity; | ||
private DataItem[] _data; | ||
private bool _dataSorted; | ||
|
||
public DataStore() | ||
: this(1000000) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Any reason? |
||
{ | ||
} | ||
|
||
public DataStore(int capacity) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Maybe instead of two constructors, we just do |
||
{ | ||
Contracts.CheckParam(capacity > 0, nameof(capacity), "must be positive"); | ||
|
||
_capacity = capacity; | ||
_data = new DataItem[Math.Min(4, capacity)]; | ||
// REVIEW: Horrifying. At a point when we have the IHost stuff plumbed through | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The horror |
||
// calibrator training and also have the appetite to change a bunch of baselines, this | ||
// should be seeded using the host random. | ||
_random = new System.Random(System.Threading.Interlocked.Increment(ref _randSeed) - 1); | ||
} | ||
|
||
/// <summary> | ||
/// An enumerator over the <see cref="DataItem"/> entries sorted by score. | ||
/// </summary> | ||
/// <returns></returns> | ||
public IEnumerator<DataItem> GetEnumerator() | ||
{ | ||
if (!_dataSorted) | ||
{ | ||
var comp = Comparer<DataItem>.Create((x, y) => x.Score.CompareTo(y.Score)); | ||
Array.Sort(_data, 0, Math.Min(_itemsSeen, _capacity), comp); | ||
_dataSorted = true; | ||
} | ||
return _data.Take(_itemsSeen).GetEnumerator(); | ||
} | ||
|
||
IEnumerator IEnumerable.GetEnumerator() | ||
{ | ||
return GetEnumerator(); | ||
} | ||
|
||
public void AddToStore(float score, bool isPositive, float weight) | ||
{ | ||
// Can't calibrate NaN scores. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Isn't it BAD to get NaN scores? |
||
if (weight == 0 || float.IsNaN(score)) | ||
return; | ||
int index = _itemsSeen++; | ||
if (_itemsSeen <= _capacity) | ||
Utils.EnsureSize(ref _data, _itemsSeen, _capacity); | ||
else | ||
{ | ||
index = _random.Next(_itemsSeen); // 0 to items_seen - 1. | ||
if (index >= _capacity) // Don't keep it. | ||
return; | ||
} | ||
_data[index] = new DataItem(isPositive, weight, score); | ||
} | ||
} | ||
protected readonly IHost Host; | ||
protected CalibrationDataStore Data; | ||
protected DataStore Data; | ||
protected const int DefaultMaxNumSamples = 1000000; | ||
protected int MaxNumSamples; | ||
|
||
|
@@ -1239,7 +1319,7 @@ protected CalibratorTrainerBase(IHostEnvironment env, string name) | |
bool ICalibratorTrainer.ProcessTrainingExample(float output, bool labelIs1, float weight) | ||
{ | ||
if (Data == null) | ||
Data = new CalibrationDataStore(MaxNumSamples); | ||
Data = new DataStore(MaxNumSamples); | ||
Data.AddToStore(output, labelIs1, weight); | ||
return true; | ||
} | ||
|
@@ -1485,15 +1565,21 @@ private static VersionInfo GetVersionInfo() | |
|
||
private readonly IHost _host; | ||
|
||
/// <summary> | ||
/// Slope value for this calibrator. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Maybe use the phrases |
||
/// </summary> | ||
public Double Slope { get; } | ||
/// <summary> | ||
/// Offset value for this calibrator | ||
/// </summary> | ||
public Double Offset { get; } | ||
bool ICanSavePfa.CanSavePfa => true; | ||
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; | ||
|
||
/// <summary> | ||
/// Initializes a new instance of <see cref="PlattCalibrator"/>. | ||
/// </summary> | ||
public PlattCalibrator(IHostEnvironment env, Double slope, Double offset) | ||
internal PlattCalibrator(IHostEnvironment env, Double slope, Double offset) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
_host = env.Register(RegistrationName); | ||
|
@@ -1556,14 +1642,15 @@ private void SaveCore(ModelSaveContext ctx) | |
} | ||
} | ||
|
||
/// <summary> Given a classifier output, produce the probability.</summary> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
break pls |
||
public float PredictProbability(float output) | ||
{ | ||
if (float.IsNaN(output)) | ||
return output; | ||
return PredictProbability(output, Slope, Offset); | ||
} | ||
|
||
public static float PredictProbability(float output, Double a, Double b) | ||
internal static float PredictProbability(float output, Double a, Double b) | ||
{ | ||
return (float)(1 / (1 + Math.Exp(a * output + b))); | ||
} | ||
|
@@ -1597,11 +1684,6 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColu | |
return true; | ||
} | ||
|
||
public string GetSummary() | ||
{ | ||
return string.Format("Platt calibrator parameters: A={0}, B={1}", Slope, Offset); | ||
} | ||
|
||
IParameterMixer IParameterMixer.CombineParameters(IList<IParameterMixer> calibrators) | ||
{ | ||
Double a = 0; | ||
|
@@ -1703,12 +1785,17 @@ public override ICalibrator CreateCalibrator(IChannel ch) | |
|
||
/// <summary> | ||
/// The pair-adjacent violators calibrator. | ||
/// The function that is implemented by this calibrator is: | ||
/// f(x) = v_i, if minX_i <= x <= maxX_i | ||
/// = linear interpolate between v_i and v_i+1, if maxX_i < x < minX_i+1 | ||
/// = v_0, if x < minX_0 | ||
/// = v_n, if x > maxX_n | ||
/// </summary> | ||
/// <remarks> | ||
/// The function that is implemented by this calibrator is: | ||
/// P(x) = | ||
/// <list type="bullet"> | ||
/// <item><description><see cref="Values"/>[i], if <see cref="Mins"/>[i] <= x <= <see cref="Maxes"/>[i]</description>></item> | ||
/// <item> <description>Linear interpolation between <see cref="Values"/>[i] and <see cref="Values"/>[i+1], if <see cref="Maxes"/>[i] < x < <see cref="Mins"/>[i+1]</description></item> | ||
/// <item><description><see cref="Values"/>[0], if x < <see cref="Mins"/>[0]</description></item> | ||
/// <item><description><see cref="Values"/>[n], if x > <see cref="Maxes"/>[n]</description></item> | ||
///</list> | ||
/// </remarks> | ||
public sealed class PavCalibrator : ICalibrator, ICanSaveInBinaryFormat | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it's not too hard, can you provide a link to any resources about this kind of calibration? You could also file an issue and assign it to the docs folks. |
||
{ | ||
internal const string LoaderSignature = "PAVCaliExec"; | ||
|
@@ -1731,8 +1818,17 @@ private static VersionInfo GetVersionInfo() | |
private const float MaxToReturn = 1 - Epsilon; // max predicted is 1 - min; | ||
|
||
private readonly IHost _host; | ||
/// <summary> | ||
/// Bottom borders of PAV intervals. | ||
/// </summary> | ||
public readonly ImmutableArray<float> Mins; | ||
/// <summary> | ||
/// Upper borders of PAV intervals. | ||
/// </summary> | ||
public readonly ImmutableArray<float> Maxes; | ||
/// <summary> | ||
/// Values of PAV intervals. | ||
/// </summary> | ||
public readonly ImmutableArray<float> Values; | ||
|
||
/// <summary> | ||
|
@@ -1742,7 +1838,7 @@ private static VersionInfo GetVersionInfo() | |
/// <param name="mins">The minimum values for each piece.</param> | ||
/// <param name="maxes">The maximum values for each piece.</param> | ||
/// <param name="values">The actual values for each piece.</param> | ||
public PavCalibrator(IHostEnvironment env, ImmutableArray<float> mins, ImmutableArray<float> maxes, ImmutableArray<float> values) | ||
internal PavCalibrator(IHostEnvironment env, ImmutableArray<float> mins, ImmutableArray<float> maxes, ImmutableArray<float> values) | ||
{ | ||
Contracts.AssertValue(env); | ||
_host = env.Register(RegistrationName); | ||
|
@@ -1851,6 +1947,7 @@ private void SaveCore(ModelSaveContext ctx) | |
_host.CheckDecode(valuePrev <= 1); | ||
} | ||
|
||
/// <summary> Given a classifier output, produce the probability.</summary> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
break There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought the style was to break the xml tags onto their own lines, but I could be mistaken. In reply to: 258188999 [](ancestors = 258188999,258187676) |
||
public float PredictProbability(float output) | ||
{ | ||
if (float.IsNaN(output)) | ||
|
@@ -1890,95 +1987,6 @@ private float FindValue(float score) | |
float t = (score - Maxes[pos - 1]) / (Mins[pos] - Maxes[pos - 1]); | ||
return Values[pos - 1] + t * (Values[pos] - Values[pos - 1]); | ||
} | ||
|
||
public string GetSummary() | ||
{ | ||
return string.Format("PAV calibrator with {0} intervals", Mins.Length); | ||
} | ||
} | ||
|
||
public sealed class CalibrationDataStore : IEnumerable<CalibrationDataStore.DataItem> | ||
{ | ||
public readonly struct DataItem | ||
{ | ||
// The actual binary label of this example. | ||
public readonly bool Target; | ||
// The weight associated with this example. | ||
public readonly float Weight; | ||
// The output of the example. | ||
public readonly float Score; | ||
|
||
public DataItem(bool target, float weight, float score) | ||
{ | ||
Target = target; | ||
Weight = weight; | ||
Score = score; | ||
} | ||
} | ||
|
||
// REVIEW: Should probably be a long. | ||
private int _itemsSeen; | ||
private readonly Random _random; | ||
|
||
private static int _randSeed; | ||
|
||
private readonly int _capacity; | ||
private DataItem[] _data; | ||
private bool _dataSorted; | ||
|
||
public CalibrationDataStore() | ||
: this(1000000) | ||
{ | ||
} | ||
|
||
public CalibrationDataStore(int capacity) | ||
{ | ||
Contracts.CheckParam(capacity > 0, nameof(capacity), "must be positive"); | ||
|
||
_capacity = capacity; | ||
_data = new DataItem[Math.Min(4, capacity)]; | ||
// REVIEW: Horrifying. At a point when we have the IHost stuff plumbed through | ||
// calibrator training and also have the appetite to change a bunch of baselines, this | ||
// should be seeded using the host random. | ||
_random = new System.Random(System.Threading.Interlocked.Increment(ref _randSeed) - 1); | ||
} | ||
|
||
/// <summary> | ||
/// An enumerator over the <see cref="DataItem"/> entries sorted by score. | ||
/// </summary> | ||
/// <returns></returns> | ||
public IEnumerator<DataItem> GetEnumerator() | ||
{ | ||
if (!_dataSorted) | ||
{ | ||
var comp = Comparer<DataItem>.Create((x, y) => x.Score.CompareTo(y.Score)); | ||
Array.Sort(_data, 0, Math.Min(_itemsSeen, _capacity), comp); | ||
_dataSorted = true; | ||
} | ||
return _data.Take(_itemsSeen).GetEnumerator(); | ||
} | ||
|
||
IEnumerator IEnumerable.GetEnumerator() | ||
{ | ||
return GetEnumerator(); | ||
} | ||
|
||
public void AddToStore(float score, bool isPositive, float weight) | ||
{ | ||
// Can't calibrate NaN scores. | ||
if (weight == 0 || float.IsNaN(score)) | ||
return; | ||
int index = _itemsSeen++; | ||
if (_itemsSeen <= _capacity) | ||
Utils.EnsureSize(ref _data, _itemsSeen, _capacity); | ||
else | ||
{ | ||
index = _random.Next(_itemsSeen); // 0 to items_seen - 1. | ||
if (index >= _capacity) // Don't keep it. | ||
return; | ||
} | ||
_data[index] = new DataItem(isPositive, weight, score); | ||
} | ||
} | ||
|
||
internal static class Calibrate | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ugh. This code makes my eyes bleed but if no one has complained after all these years we can probably let it stand for a bit. :)