Skip to content

Commit 1aaa77f

Browse files
committed
Modified MYTODO comments
1 parent 0902927 commit 1aaa77f

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,13 +218,18 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
218218

219219
private protected override void InitializeBeforeTraining()
220220
{
221-
_numClass = -1; //MYTODO: Include more initializations, of TrainedEnsemble, for example?
221+
_numClass = -1;
222222
_tlcNumClass = 0;
223-
}
223+
224+
//MYTODO: Include more initializations, of TrainedEnsemble, for example?
225+
//For example:
226+
//TrainedEnsemble = null;
227+
}
224228

225229
private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData data, float[] labels)
226230
{
227231
// Only initialize one time.
232+
228233
if (_numClass < 0)
229234
{
230235
float minLabel = float.MaxValue;
@@ -261,6 +266,13 @@ private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData dat
261266
_tlcNumClass = (int)maxLabel + 1;
262267
}
263268
}
269+
270+
// If there are NaN labels, they are converted to be equal to _tlcNumClass (i.e. _numClass - 1).
271+
// This is done because NaN labels are going to be seen as
272+
// an extra different class, and thus, when training the model in the WrappedLightGbmTraining class
273+
// a total of _numClass classes are considered. But, when creating the Predictors, only _tlcNumClass number of
274+
// classes are considered ignoring the extra class of NaN labels.
275+
264276
float defaultLabel = _numClass - 1;
265277
for (int i = 0; i < labels.Length; ++i)
266278
if (float.IsNaN(labels[i]))

src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ private protected override TModel TrainModelCore(TrainContext context)
374374
}
375375

376376
private protected virtual void InitializeBeforeTraining()
377-
{ return; }
377+
{ return; } // MYTODO: Is there a better way to avoid having to do this? An abstract method?
378378

379379
private void InitParallelTraining()
380380
{

0 commit comments

Comments
 (0)