@@ -218,13 +218,18 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
218
218
219
219
private protected override void InitializeBeforeTraining ( )
220
220
{
221
- _numClass = - 1 ; //MYTODO: Include more initializations, of TrainedEnsemble, for example?
221
+ _numClass = - 1 ;
222
222
_tlcNumClass = 0 ;
223
- }
223
+
224
+ //MYTODO: Include more initializations, of TrainedEnsemble, for example?
225
+ //For example:
226
+ //TrainedEnsemble = null;
227
+ }
224
228
225
229
private protected override void ConvertNaNLabels ( IChannel ch , RoleMappedData data , float [ ] labels )
226
230
{
227
231
// Only initialize one time.
232
+
228
233
if ( _numClass < 0 )
229
234
{
230
235
float minLabel = float . MaxValue ;
@@ -261,6 +266,13 @@ private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData dat
261
266
_tlcNumClass = ( int ) maxLabel + 1 ;
262
267
}
263
268
}
269
+
270
+ // If there are NaN labels, they are converted to be equal to _tlcNumClass (i.e. _numClass - 1).
271
+ // This is done because NaN labels are going to be seen as
272
+ // an extra different class, and thus, when training the model in the WrappedLightGbmTraining class
273
+ // a total of _numClass classes are considered. But, when creating the Predictors, only _tlcNumClass number of
274
+ // classes are considered ignoring the extra class of NaN labels.
275
+
264
276
float defaultLabel = _numClass - 1 ;
265
277
for ( int i = 0 ; i < labels . Length ; ++ i )
266
278
if ( float . IsNaN ( labels [ i ] ) )
0 commit comments