Skip to content

Commit c479a06

Browse files
committed
Internal-best-friend TreeEnsemble
1 parent ca08b44 commit c479a06

19 files changed

+48
-24
lines changed

src/Microsoft.ML.FastTree/FastTree.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
5454
{
5555
protected readonly TArgs Args;
5656
protected readonly bool AllowGC;
57-
protected TreeEnsemble TrainedEnsemble;
57+
[BestFriend]
58+
internal TreeEnsemble TrainedEnsemble;
5859
protected int FeatureCount;
5960
private protected RoleMappedData ValidData;
6061
/// <summary>
@@ -87,7 +88,8 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
8788
protected double[] InitValidScores;
8889
protected double[][] InitTestScores;
8990
//protected int Iteration;
90-
protected TreeEnsemble Ensemble;
91+
[BestFriend]
92+
internal TreeEnsemble Ensemble;
9193

9294
protected bool HasValidSet => ValidSet != null;
9395

@@ -2853,7 +2855,8 @@ public abstract class TreeEnsembleModelParameters :
28532855
/// </summary>
28542856
public FeatureContributionCalculator FeatureContributionCalculator => new FeatureContributionCalculator(this);
28552857

2856-
public TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemble trainedEnsemble, int numFeatures, string innerArgs)
2858+
[BestFriend]
2859+
internal TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemble trainedEnsemble, int numFeatures, string innerArgs)
28572860
: base(env, name)
28582861
{
28592862
Host.CheckValue(trainedEnsemble, nameof(trainedEnsemble));

src/Microsoft.ML.FastTree/FastTreeClassification.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ private static VersionInfo GetVersionInfo()
6969

7070
protected override uint VerCategoricalSplitSerialized => 0x00010005;
7171

72-
public FastTreeBinaryModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
72+
[BestFriend]
73+
internal FastTreeBinaryModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
7374
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
7475
{
7576
}

src/Microsoft.ML.FastTree/FastTreeRanking.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1130,7 +1130,8 @@ private static VersionInfo GetVersionInfo()
11301130

11311131
protected override uint VerCategoricalSplitSerialized => 0x00010005;
11321132

1133-
public FastTreeRankingModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
1133+
[BestFriend]
1134+
internal FastTreeRankingModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
11341135
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
11351136
{
11361137
}

src/Microsoft.ML.FastTree/FastTreeRegression.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,8 @@ private static VersionInfo GetVersionInfo()
466466

467467
protected override uint VerCategoricalSplitSerialized => 0x00010005;
468468

469-
public FastTreeRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
469+
[BestFriend]
470+
internal FastTreeRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
470471
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
471472
{
472473
}

src/Microsoft.ML.FastTree/FastTreeTweedie.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,8 @@ private static VersionInfo GetVersionInfo()
469469

470470
protected override uint VerCategoricalSplitSerialized => 0x00010003;
471471

472-
public FastTreeTweedieModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
472+
[BestFriend]
473+
internal FastTreeTweedieModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
473474
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
474475
{
475476
}

src/Microsoft.ML.FastTree/RandomForestClassification.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ private static VersionInfo GetVersionInfo()
7979
/// </summary>
8080
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
8181

82-
public FastForestClassificationModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
82+
[BestFriend]
83+
internal FastForestClassificationModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
8384
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
8485
{ }
8586

src/Microsoft.ML.FastTree/RandomForestRegression.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ private static VersionInfo GetVersionInfo()
5959

6060
protected override uint VerCategoricalSplitSerialized => 0x00010006;
6161

62-
public FastForestRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs, int samplesCount)
62+
[BestFriend]
63+
internal FastForestRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs, int samplesCount)
6364
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
6465
{
6566
_quantileSampleCount = samplesCount;

src/Microsoft.ML.FastTree/Training/BaggingProvider.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ public int GetBagCount(int numTrees, int bagSize)
7575
// Divides output values of leaves to bag count.
7676
// This brings back the final scores generated by model on a same
7777
// range as when we didn't use bagging
78-
public void ScaleEnsembleLeaves(int numTrees, int bagSize, TreeEnsemble ensemble)
78+
[BestFriend]
79+
internal void ScaleEnsembleLeaves(int numTrees, int bagSize, TreeEnsemble ensemble)
7980
{
8081
int bagCount = GetBagCount(numTrees, bagSize);
8182
for (int t = 0; t < ensemble.NumTrees; t++)

src/Microsoft.ML.FastTree/Training/EnsembleCompression/IEnsembleCompressor.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
namespace Microsoft.ML.Trainers.FastTree.Internal
66
{
7-
public interface IEnsembleCompressor<TLabel>
7+
[BestFriend]
8+
internal interface IEnsembleCompressor<TLabel>
89
{
910
void Initialize(int numTrees, Dataset trainSet, TLabel[] labels, int randomSeed);
1011

src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoBasedEnsembleCompressor.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ namespace Microsoft.ML.Trainers.FastTree.Internal
1515
/// https://www-stat.stanford.edu/~hastie/Papers/glmnet.pdf
1616
/// </summary>
1717
/// <remarks>Author was Yasser Ganjisaffar during his internship.</remarks>
18-
public class LassoBasedEnsembleCompressor : IEnsembleCompressor<short>
18+
[BestFriend]
19+
internal class LassoBasedEnsembleCompressor : IEnsembleCompressor<short>
1920
{
2021
// This module shouldn't consume more than 4GB of memory
2122
private const long MaxAvailableMemory = 4L * 1024 * 1024 * 1024;
@@ -533,7 +534,8 @@ private unsafe void LoadTargets(double[] trainScores, int bestIteration)
533534
}
534535
}
535536

536-
public bool Compress(IChannel ch, TreeEnsemble ensemble, double[] trainScores, int bestIteration, int maxTreesAfterCompression)
537+
[BestFriend]
538+
bool IEnsembleCompressor<short>.Compress(IChannel ch, TreeEnsemble ensemble, double[] trainScores, int bestIteration, int maxTreesAfterCompression)
537539
{
538540
LoadTargets(trainScores, bestIteration);
539541

@@ -551,7 +553,8 @@ public bool Compress(IChannel ch, TreeEnsemble ensemble, double[] trainScores, i
551553
return true;
552554
}
553555

554-
public TreeEnsemble GetCompressedEnsemble()
556+
[BestFriend]
557+
TreeEnsemble IEnsembleCompressor<short>.GetCompressedEnsemble()
555558
{
556559
return _compressedEnsemble;
557560
}

src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/AcceleratedGradientDescent.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ namespace Microsoft.ML.Trainers.FastTree.Internal
77
//Accelerated gradient descent score tracker
88
public class AcceleratedGradientDescent : GradientDescent
99
{
10-
public AcceleratedGradientDescent(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
10+
[BestFriend]
11+
internal AcceleratedGradientDescent(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
1112
: base(ensemble, trainData, initTrainScores, gradientWrapper)
1213
{
1314
UseFastTrainingScoresUpdate = false;

src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/ConjugateGradientDescent.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ public class ConjugateGradientDescent : GradientDescent
1111
private double[] _currentGradient;
1212
private double[] _currentDk;
1313

14-
public ConjugateGradientDescent(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
14+
[BestFriend]
15+
internal ConjugateGradientDescent(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
1516
: base(ensemble, trainData, initTrainScores, gradientWrapper)
1617
{
1718
_currentDk = new double[trainData.NumDocs];

src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/GradientDescent.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ public class GradientDescent : OptimizationAlgorithm
2121
private double[] _droppedScores;
2222
private double[] _scores;
2323

24-
public GradientDescent(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
24+
[BestFriend]
25+
internal GradientDescent(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
2526
: base(ensemble, trainData, initTrainScores)
2627
{
2728
_gradientWrapper = gradientWrapper;

src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/NoOptimizationAlgorithm.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ public class RandomForestOptimizer : GradientDescent
1212
{
1313
private IGradientAdjuster _gradientWrapper;
1414
// REVIEW: When the FastTree appliation is decoupled with tree learner and boosting logic, this class should be removed.
15-
public RandomForestOptimizer(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
15+
[BestFriend]
16+
internal RandomForestOptimizer(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
1617
: base(ensemble, trainData, initTrainScores, gradientWrapper)
1718
{
1819
_gradientWrapper = gradientWrapper;

src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/OptimizationAlgorithm.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ public abstract class OptimizationAlgorithm
2525
public delegate void PreScoreUpdateHandler(IChannel ch);
2626
public PreScoreUpdateHandler PreScoreUpdateEvent;
2727

28-
public TreeEnsemble Ensemble;
28+
[BestFriend]
29+
internal TreeEnsemble Ensemble;
2930

3031
public ScoreTracker TrainingScores;
3132
public List<ScoreTracker> TrackedScores;
@@ -36,7 +37,8 @@ public abstract class OptimizationAlgorithm
3637
public Random DropoutRng;
3738
public bool UseFastTrainingScoresUpdate;
3839

39-
public OptimizationAlgorithm(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores)
40+
[BestFriend]
41+
internal OptimizationAlgorithm(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores)
4042
{
4143
Ensemble = ensemble;
4244
TrainingScores = ConstructScoreTracker("train", trainData, initTrainScores);

src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ internal TreeEnsembleView(TreeEnsemble treeEnsemble)
4545
}
4646
}
4747

48-
public class TreeEnsemble
48+
[BestFriend]
49+
internal class TreeEnsemble
4950
{
5051
/// <summary>
5152
/// String appended to the text representation of <see cref="TreeEnsemble"/>. This is mainly used in <see cref="ToTreeEnsembleIni"/>.

src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ private static VersionInfo GetVersionInfo()
5555
protected override uint VerCategoricalSplitSerialized => 0x00010005;
5656
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
5757

58-
public LightGbmBinaryModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
58+
[BestFriend]
59+
internal LightGbmBinaryModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
5960
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
6061
{
6162
}

src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ private static VersionInfo GetVersionInfo()
5050
protected override uint VerCategoricalSplitSerialized => 0x00010005;
5151
public override PredictionKind PredictionKind => PredictionKind.Ranking;
5252

53-
public LightGbmRankingModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
53+
[BestFriend]
54+
internal LightGbmRankingModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
5455
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
5556
{
5657
}

src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ private static VersionInfo GetVersionInfo()
5050
protected override uint VerCategoricalSplitSerialized => 0x00010005;
5151
public override PredictionKind PredictionKind => PredictionKind.Regression;
5252

53-
public LightGbmRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
53+
[BestFriend]
54+
internal LightGbmRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
5455
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
5556
{
5657
}

0 commit comments

Comments
 (0)