Skip to content

Commit a5d31a6

Browse files
committed
PR feedback.
1 parent f44c46a commit a5d31a6

12 files changed

+149
-148
lines changed

src/Microsoft.ML.FastTree/BoostingFastTree.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ internal BoostingFastTreeTrainerBase(IHostEnvironment env,
3232
Args.LearningRates = learningRate;
3333
}
3434

35-
internal override void CheckArgs(IChannel ch)
35+
private protected override void CheckArgs(IChannel ch)
3636
{
3737
if (Args.OptimizationAlgorithm == BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent)
3838
Args.UseLineSearch = true;
@@ -102,7 +102,7 @@ private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(
102102
return optimizationAlgorithm;
103103
}
104104

105-
internal override IGradientAdjuster MakeGradientWrapper(IChannel ch)
105+
private protected override IGradientAdjuster MakeGradientWrapper(IChannel ch)
106106
{
107107
if (!Args.BestStepRankingRegressionTrees)
108108
return base.MakeGradientWrapper(ch);
@@ -115,7 +115,7 @@ internal override IGradientAdjuster MakeGradientWrapper(IChannel ch)
115115
return new BestStepRegressionGradientWrapper();
116116
}
117117

118-
internal override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earlyStoppingRule, ref int bestIteration)
118+
private protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earlyStoppingRule, ref int bestIteration)
119119
{
120120
if (Args.EarlyStoppingRule == null)
121121
return false;
@@ -147,7 +147,7 @@ internal override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion early
147147
return shouldStop;
148148
}
149149

150-
internal override int GetBestIteration(IChannel ch)
150+
private protected override int GetBestIteration(IChannel ch)
151151
{
152152
int bestIteration = Ensemble.NumTrees;
153153
if (!Args.WriteLastEnsemble && PruningTest != null)
@@ -169,7 +169,7 @@ internal double BsrMaxTreeOutput()
169169
return -1;
170170
}
171171

172-
internal override bool ShouldRandomStartOptimizer()
172+
private protected override bool ShouldRandomStartOptimizer()
173173
{
174174
return Args.RandomStart;
175175
}

src/Microsoft.ML.FastTree/FastTree.cs

Lines changed: 70 additions & 69 deletions
Large diffs are not rendered by default.

src/Microsoft.ML.FastTree/FastTreeClassification.cs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,11 @@ private static VersionInfo GetVersionInfo()
6262
loaderAssemblyName: typeof(FastTreeBinaryModelParameters).Assembly.FullName);
6363
}
6464

65-
internal override uint VerNumFeaturesSerialized => 0x00010002;
65+
private protected override uint VerNumFeaturesSerialized => 0x00010002;
6666

67-
internal override uint VerDefaultValueSerialized => 0x00010004;
67+
private protected override uint VerDefaultValueSerialized => 0x00010004;
6868

69-
internal override uint VerCategoricalSplitSerialized => 0x00010005;
69+
private protected override uint VerCategoricalSplitSerialized => 0x00010005;
7070

7171
internal FastTreeBinaryModelParameters(IHostEnvironment env, InternalTreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
7272
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
@@ -188,7 +188,7 @@ private protected override CalibratedModelParametersBase<FastTreeBinaryModelPara
188188
return new FeatureWeightsCalibratedModelParameters<FastTreeBinaryModelParameters, PlattCalibrator>(Host, pred, cali);
189189
}
190190

191-
internal override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
191+
private protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
192192
{
193193
return new ObjectiveImpl(
194194
TrainSet,
@@ -223,18 +223,18 @@ private IEnumerable<bool> GetClassificationLabelsFromRatings(Dataset set)
223223
return set.Ratings.Select(x => x >= 1);
224224
}
225225

226-
internal override void PrepareLabels(IChannel ch)
226+
private protected override void PrepareLabels(IChannel ch)
227227
{
228228
_trainSetLabels = GetClassificationLabelsFromRatings(TrainSet).ToArray(TrainSet.NumDocs);
229229
//Here we set regression labels to what is in bin file if the values were not overriden with floats
230230
}
231231

232-
internal override Test ConstructTestForTrainingData()
232+
private protected override Test ConstructTestForTrainingData()
233233
{
234234
return new BinaryClassificationTest(ConstructScoreTracker(TrainSet), _trainSetLabels, _sigmoidParameter);
235235
}
236236

237-
internal override void InitializeTests()
237+
private protected override void InitializeTests()
238238
{
239239
//Always compute training L1/L2 errors
240240
TrainTest = new BinaryClassificationTest(ConstructScoreTracker(TrainSet), _trainSetLabels, _sigmoidParameter);

src/Microsoft.ML.FastTree/FastTreeRanking.cs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
106106
if (!labelCol.IsKey && labelCol.ItemType != NumberDataViewType.Single)
107107
error();
108108
}
109-
internal override float GetMaxLabel()
109+
private protected override float GetMaxLabel()
110110
{
111111
return GetLabelGains().Length - 1;
112112
}
@@ -143,7 +143,7 @@ private Double[] GetLabelGains()
143143
}
144144
}
145145

146-
internal override void CheckArgs(IChannel ch)
146+
private protected override void CheckArgs(IChannel ch)
147147
{
148148
if (!string.IsNullOrEmpty(Args.CustomGains))
149149
{
@@ -173,7 +173,7 @@ internal override void CheckArgs(IChannel ch)
173173
base.CheckArgs(ch);
174174
}
175175

176-
internal override void Initialize(IChannel ch)
176+
private protected override void Initialize(IChannel ch)
177177
{
178178
base.Initialize(ch);
179179
if (Args.CompressEnsemble)
@@ -183,7 +183,7 @@ internal override void Initialize(IChannel ch)
183183
}
184184
}
185185

186-
internal override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
186+
private protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
187187
{
188188
return new LambdaRankObjectiveFunction(TrainSet, TrainSet.Ratings, Args, ParallelTraining);
189189
}
@@ -199,22 +199,22 @@ private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(
199199
return optimizationAlgorithm;
200200
}
201201

202-
internal override BaggingProvider CreateBaggingProvider()
202+
private protected override BaggingProvider CreateBaggingProvider()
203203
{
204204
Host.Assert(Args.BaggingSize > 0);
205205
return new RankingBaggingProvider(TrainSet, Args.NumLeaves, Args.RngSeed, Args.BaggingTrainFraction);
206206
}
207207

208-
internal override void PrepareLabels(IChannel ch)
208+
private protected override void PrepareLabels(IChannel ch)
209209
{
210210
}
211211

212-
internal override Test ConstructTestForTrainingData()
212+
private protected override Test ConstructTestForTrainingData()
213213
{
214214
return new NdcgTest(ConstructScoreTracker(TrainSet), TrainSet.Ratings, Args.SortingAlgorithm);
215215
}
216216

217-
internal override void InitializeTests()
217+
private protected override void InitializeTests()
218218
{
219219
if (Args.TestFrequency != int.MaxValue)
220220
{
@@ -280,7 +280,7 @@ private void AddFullTests()
280280
}
281281
}
282282

283-
internal override void PrintIterationMessage(IChannel ch, IProgressChannel pch)
283+
private protected override void PrintIterationMessage(IChannel ch, IProgressChannel pch)
284284
{
285285
// REVIEW: Shift to using progress channels to report this information.
286286
#if OLD_TRACE
@@ -316,7 +316,7 @@ internal override void PrintIterationMessage(IChannel ch, IProgressChannel pch)
316316
#endif
317317
}
318318

319-
internal override void ComputeTests()
319+
private protected override void ComputeTests()
320320
{
321321
if (_firstTestSetHistory != null)
322322
_firstTestSetHistory.ComputeTests();
@@ -328,7 +328,7 @@ internal override void ComputeTests()
328328
PruningTest.ComputeTests();
329329
}
330330

331-
internal override string GetTestGraphLine()
331+
private protected override string GetTestGraphLine()
332332
{
333333
StringBuilder lineBuilder = new StringBuilder();
334334

@@ -361,7 +361,7 @@ internal override string GetTestGraphLine()
361361
return lineBuilder.ToString();
362362
}
363363

364-
internal override void Train(IChannel ch)
364+
private protected override void Train(IChannel ch)
365365
{
366366
base.Train(ch);
367367
// Print final last iteration.
@@ -438,7 +438,7 @@ private Test CreateFirstTestSetTest()
438438
/// Get the header of test graph
439439
/// </summary>
440440
/// <returns>Test graph header</returns>
441-
internal override string GetTestGraphHeader()
441+
private protected override string GetTestGraphHeader()
442442
{
443443
StringBuilder headerBuilder = new StringBuilder("Eval:\tFileName\tNDCG@1\tNDCG@2\tNDCG@3\tNDCG@4\tNDCG@5\tNDCG@6\tNDCG@7\tNDCG@8\tNDCG@9\tNDCG@10");
444444

@@ -1127,11 +1127,11 @@ private static VersionInfo GetVersionInfo()
11271127
loaderAssemblyName: typeof(FastTreeRankingModelParameters).Assembly.FullName);
11281128
}
11291129

1130-
internal override uint VerNumFeaturesSerialized => 0x00010002;
1130+
private protected override uint VerNumFeaturesSerialized => 0x00010002;
11311131

1132-
internal override uint VerDefaultValueSerialized => 0x00010004;
1132+
private protected override uint VerDefaultValueSerialized => 0x00010004;
11331133

1134-
internal override uint VerCategoricalSplitSerialized => 0x00010005;
1134+
private protected override uint VerCategoricalSplitSerialized => 0x00010005;
11351135

11361136
internal FastTreeRankingModelParameters(IHostEnvironment env, InternalTreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
11371137
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)

src/Microsoft.ML.FastTree/FastTreeRegression.cs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ private protected override FastTreeRegressionModelParameters TrainModelCore(Trai
101101
return new FastTreeRegressionModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs);
102102
}
103103

104-
internal override void CheckArgs(IChannel ch)
104+
private protected override void CheckArgs(IChannel ch)
105105
{
106106
Contracts.AssertValue(ch);
107107

@@ -116,7 +116,7 @@ private static SchemaShape.Column MakeLabelColumn(string labelColumn)
116116
return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false);
117117
}
118118

119-
internal override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
119+
private protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
120120
{
121121
return new ObjectiveImpl(TrainSet, Args);
122122
}
@@ -152,11 +152,11 @@ internal static float[] GetDatasetRegressionLabels(Dataset set)
152152
return dlabels.Select(x => (float)x).ToArray(dlabels.Length);
153153
}
154154

155-
internal override void PrepareLabels(IChannel ch)
155+
private protected override void PrepareLabels(IChannel ch)
156156
{
157157
}
158158

159-
internal override Test ConstructTestForTrainingData()
159+
private protected override Test ConstructTestForTrainingData()
160160
{
161161
return new RegressionTest(ConstructScoreTracker(TrainSet));
162162
}
@@ -226,7 +226,7 @@ protected virtual void AddFullNDCGTests()
226226
}
227227
#endif
228228

229-
internal override void InitializeTests()
229+
private protected override void InitializeTests()
230230
{
231231
// Initialize regression tests.
232232
if (Args.TestFrequency != int.MaxValue)
@@ -267,7 +267,7 @@ internal override void InitializeTests()
267267
}
268268
}
269269

270-
internal override void PrintIterationMessage(IChannel ch, IProgressChannel pch)
270+
private protected override void PrintIterationMessage(IChannel ch, IProgressChannel pch)
271271
{
272272
// REVIEW: Shift this to use progress channels.
273273
#if OLD_TRACING
@@ -307,7 +307,7 @@ internal override void PrintIterationMessage(IChannel ch, IProgressChannel pch)
307307
#endif
308308
}
309309

310-
internal override string GetTestGraphHeader()
310+
private protected override string GetTestGraphHeader()
311311
{
312312
StringBuilder headerBuilder = new StringBuilder("Eval:\tFileName\tNDCG@1\tNDCG@2\tNDCG@3\tNDCG@4\tNDCG@5\tNDCG@6\tNDCG@7\tNDCG@8\tNDCG@9\tNDCG@10");
313313

@@ -320,7 +320,7 @@ internal override string GetTestGraphHeader()
320320
return headerBuilder.ToString();
321321
}
322322

323-
internal override void ComputeTests()
323+
private protected override void ComputeTests()
324324
{
325325
if (_firstTestSetHistory != null)
326326
{
@@ -343,7 +343,7 @@ internal override void ComputeTests()
343343
}
344344
}
345345

346-
internal override string GetTestGraphLine()
346+
private protected override string GetTestGraphLine()
347347
{
348348
StringBuilder lineBuilder = new StringBuilder();
349349

@@ -371,7 +371,7 @@ internal override string GetTestGraphLine()
371371
return lineBuilder.ToString();
372372
}
373373

374-
internal override void Train(IChannel ch)
374+
private protected override void Train(IChannel ch)
375375
{
376376
base.Train(ch);
377377
// Print final last iteration.
@@ -462,11 +462,11 @@ private static VersionInfo GetVersionInfo()
462462
loaderAssemblyName: typeof(FastTreeRegressionModelParameters).Assembly.FullName);
463463
}
464464

465-
internal override uint VerNumFeaturesSerialized => 0x00010002;
465+
private protected override uint VerNumFeaturesSerialized => 0x00010002;
466466

467-
internal override uint VerDefaultValueSerialized => 0x00010004;
467+
private protected override uint VerDefaultValueSerialized => 0x00010004;
468468

469-
internal override uint VerCategoricalSplitSerialized => 0x00010005;
469+
private protected override uint VerCategoricalSplitSerialized => 0x00010005;
470470

471471
internal FastTreeRegressionModelParameters(IHostEnvironment env, InternalTreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
472472
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)

0 commit comments

Comments
 (0)