From b1b1823f958badbff1e133a525a123ab9a2f84ae Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Wed, 5 Dec 2018 17:36:39 -0800 Subject: [PATCH 1/6] Internalize and explicitly implement ICanSAveInIniFormat, ICanSaveInSourceCode, ICanSaveSummary, ICanSaveSummaryInKeyValuePairs, and ICanGetSummaryAsIRow --- .../Dirty/PredictorInterfaces.cs | 15 +- .../Prediction/Calibrator.cs | 8 +- .../Scorers/SchemaBindablePredictorWrapper.cs | 2 +- src/Microsoft.ML.Ensemble/PipelineEnsemble.cs | 4 +- .../Trainer/EnsemblePredictorBase.cs | 2 +- src/Microsoft.ML.FastTree/FastTree.cs | 16 +- src/Microsoft.ML.FastTree/GamTrainer.cs | 2 +- .../OlsLinearRegression.cs | 2 +- src/Microsoft.ML.Legacy/CSharpApi.cs | 342 ++++++++++++++++++ src/Microsoft.ML.PCA/PcaTrainer.cs | 2 +- .../Standard/LinearPredictor.cs | 38 +- .../MulticlassLogisticRegression.cs | 12 +- .../Standard/MultiClass/Ova.cs | 2 +- .../SentimentPredictionTests.cs | 5 +- 14 files changed, 406 insertions(+), 46 deletions(-) diff --git a/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs b/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs index a1baf261c2..9b49be96b0 100644 --- a/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs +++ b/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs @@ -101,7 +101,8 @@ internal interface ICanSaveInTextFormat /// /// Predictors that can output themselves in the Bing ini format. /// - public interface ICanSaveInIniFormat + [BestFriend] + internal interface ICanSaveInIniFormat { void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator = null); } @@ -109,7 +110,8 @@ public interface ICanSaveInIniFormat /// /// Predictors that can output Summary. /// - public interface ICanSaveSummary + [BestFriend] + internal interface ICanSaveSummary { void SaveSummary(TextWriter writer, RoleMappedSchema schema); } @@ -119,7 +121,8 @@ public interface ICanSaveSummary /// The content of value 'object' can be any type such as integer, float, string or an array of them. 
/// It is up the caller to check and decide how to consume the values. /// - public interface ICanGetSummaryInKeyValuePairs + [BestFriend] + internal interface ICanGetSummaryInKeyValuePairs { /// /// Gets model summary including model statistics (if exists) in key value pairs. @@ -127,7 +130,8 @@ public interface ICanGetSummaryInKeyValuePairs IList> GetSummaryInKeyValuePairs(RoleMappedSchema schema); } - public interface ICanGetSummaryAsIRow + [BestFriend] + internal interface ICanGetSummaryAsIRow { Row GetSummaryIRowOrNull(RoleMappedSchema schema); @@ -142,7 +146,8 @@ public interface ICanGetSummaryAsIDataView /// /// Predictors that can output themselves in C#/C++ code. /// - public interface ICanSaveInSourceCode + [BestFriend] + internal interface ICanSaveInSourceCode { void SaveAsCode(TextWriter writer, RoleMappedSchema schema); } diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 45860d2e39..45ccbf966a 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -152,7 +152,7 @@ protected CalibratedPredictorBase(IHostEnvironment env, string name, IPredictorP Calibrator = calibrator; } - public void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator = null) + void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator) { Host.Check(calibrator == null, "Too many calibrators."); var saver = SubPredictor as ICanSaveInIniFormat; @@ -167,7 +167,7 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) saver.SaveAsText(writer, schema); } - public void SaveAsCode(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema) { // REVIEW: What about the calibrator? 
var saver = SubPredictor as ICanSaveInSourceCode; @@ -175,7 +175,7 @@ public void SaveAsCode(TextWriter writer, RoleMappedSchema schema) saver.SaveAsCode(writer, schema); } - public void SaveSummary(TextWriter writer, RoleMappedSchema schema) + void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) { // REVIEW: What about the calibrator? var saver = SubPredictor as ICanSaveSummary; @@ -184,7 +184,7 @@ public void SaveSummary(TextWriter writer, RoleMappedSchema schema) } /// - public IList> GetSummaryInKeyValuePairs(RoleMappedSchema schema) + IList> ICanGetSummaryInKeyValuePairs.GetSummaryInKeyValuePairs(RoleMappedSchema schema) { // REVIEW: What about the calibrator? var saver = SubPredictor as ICanGetSummaryInKeyValuePairs; diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs index 279d2a4ad0..f97e3516a3 100644 --- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs +++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs @@ -171,7 +171,7 @@ private ValueGetter GetValueGetter(Row input, int colSrc) }; } - public void SaveSummary(TextWriter writer, RoleMappedSchema schema) + void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) { var summarySaver = Predictor as ICanSaveSummary; if (summarySaver == null) diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs index dcce4f0018..e9d0902456 100644 --- a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs @@ -557,7 +557,7 @@ public static SchemaBindablePipelineEnsembleBase Create(IHostEnvironment env, Mo public abstract ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema); - public void SaveSummary(TextWriter writer, RoleMappedSchema schema) + void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) { for (int i = 0; i < 
PredictorModels.Length; i++) { @@ -688,7 +688,7 @@ private static bool AreEqual(in VBuffer v1, in VBuffer v2) /// - If neither of those interfaces are implemented then the value is a string containing the name of the type of model. /// /// - public IList> GetSummaryInKeyValuePairs(RoleMappedSchema schema) + IList> ICanGetSummaryInKeyValuePairs.GetSummaryInKeyValuePairs(RoleMappedSchema schema) { Host.CheckValueOrNull(schema); diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs index 72a81b12e3..4b9ddb89dd 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsemblePredictorBase.cs @@ -144,7 +144,7 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) /// /// Saves the model summary /// - public void SaveSummary(TextWriter writer, RoleMappedSchema schema) + void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) { for (int i = 0; i < Models.Length; i++) { diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 8676331bd2..5e3ee93891 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -2963,7 +2963,7 @@ private void FeatureContributionMap(in VBuffer src, ref VBuffer ds /// /// write out a C# representation of the ensemble /// - public void SaveAsCode(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema) { Host.CheckValueOrNull(schema); SaveEnsembleAsCode(writer, schema); @@ -2976,13 +2976,13 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) { Host.CheckValue(writer, nameof(writer)); Host.CheckValueOrNull(schema); - SaveAsIni(writer, schema); + ((ICanSaveInIniFormat)this).SaveAsIni(writer, schema); } /// /// Output the INI model to a given writer /// - public void SaveAsIni(TextWriter 
writer, RoleMappedSchema schema, ICalibrator calibrator = null) + void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator) { Host.CheckValue(writer, nameof(writer)); Host.CheckValue(schema, nameof(schema)); @@ -3156,12 +3156,12 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string return true; } - public void SaveSummary(TextWriter writer, RoleMappedSchema schema) + void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) { writer.WriteLine(); writer.WriteLine("Per-feature gain summary for the boosted tree ensemble:"); - foreach (var pair in GetSummaryInKeyValuePairs(schema)) + foreach (var pair in ((ICanGetSummaryInKeyValuePairs)this).GetSummaryInKeyValuePairs(schema)) { Host.Assert(pair.Value is Double); writer.WriteLine("\t{0}\t{1}", pair.Key, (Double)pair.Value); @@ -3187,7 +3187,7 @@ private IEnumerable> GetSortedFeatureGains(RoleMapp } /// - public IList> GetSummaryInKeyValuePairs(RoleMappedSchema schema) + IList> ICanGetSummaryInKeyValuePairs.GetSummaryInKeyValuePairs(RoleMappedSchema schema) { List> results = new List>(); @@ -3309,7 +3309,7 @@ public int GetLeaf(int treeId, in VBuffer features, ref List path) return TrainedEnsemble.GetTreeAt(treeId).GetLeaf(in features, ref path); } - public Row GetSummaryIRowOrNull(RoleMappedSchema schema) + Row ICanGetSummaryAsIRow.GetSummaryIRowOrNull(RoleMappedSchema schema) { var names = default(VBuffer>); MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, NumFeatures, ref names); @@ -3324,7 +3324,7 @@ public Row GetSummaryIRowOrNull(RoleMappedSchema schema) return MetadataUtils.MetadataAsRow(builder.GetMetadata()); } - public Row GetStatsIRowOrNull(RoleMappedSchema schema) + Row ICanGetSummaryAsIRow.GetStatsIRowOrNull(RoleMappedSchema schema) { return null; } diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index 7943af8835..15d599e237 100644 --- 
a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -1043,7 +1043,7 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) } } - public void SaveSummary(TextWriter writer, RoleMappedSchema schema) + void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) { ((ICanSaveInTextFormat)this).SaveAsText(writer, schema); } diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs index c4fdd85e22..5648aa29e4 100644 --- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs +++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs @@ -738,7 +738,7 @@ public static OlsLinearRegressionPredictor Create(IHostEnvironment env, ModelLoa return new OlsLinearRegressionPredictor(env, ctx); } - public override void SaveSummary(TextWriter writer, RoleMappedSchema schema) + private protected override void SaveSummary(TextWriter writer, RoleMappedSchema schema) { var names = default(VBuffer>); MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, Weight.Length, ref names); diff --git a/src/Microsoft.ML.Legacy/CSharpApi.cs b/src/Microsoft.ML.Legacy/CSharpApi.cs index c7e5c3ff89..bcd35d4988 100644 --- a/src/Microsoft.ML.Legacy/CSharpApi.cs +++ b/src/Microsoft.ML.Legacy/CSharpApi.cs @@ -400,6 +400,20 @@ public void Add(Microsoft.ML.Legacy.Models.PAVCalibrator input, Microsoft.ML.Leg _jsonNodes.Add(Serialize("Models.PAVCalibrator", input, output)); } + [Obsolete] + public Microsoft.ML.Legacy.Models.PipelineSweeper.Output Add(Microsoft.ML.Legacy.Models.PipelineSweeper input) + { + var output = new Microsoft.ML.Legacy.Models.PipelineSweeper.Output(); + Add(input, output); + return output; + } + + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.PipelineSweeper input, Microsoft.ML.Legacy.Models.PipelineSweeper.Output output) + { + _jsonNodes.Add(Serialize("Models.PipelineSweeper", input, output)); + } + [Obsolete] public 
Microsoft.ML.Legacy.Models.PlattCalibrator.Output Add(Microsoft.ML.Legacy.Models.PlattCalibrator input) { @@ -498,6 +512,20 @@ public void Add(Microsoft.ML.Legacy.Models.Summarizer input, Microsoft.ML.Legacy _jsonNodes.Add(Serialize("Models.Summarizer", input, output)); } + [Obsolete] + public Microsoft.ML.Legacy.Models.SweepResultExtractor.Output Add(Microsoft.ML.Legacy.Models.SweepResultExtractor input) + { + var output = new Microsoft.ML.Legacy.Models.SweepResultExtractor.Output(); + Add(input, output); + return output; + } + + [Obsolete] + public void Add(Microsoft.ML.Legacy.Models.SweepResultExtractor input, Microsoft.ML.Legacy.Models.SweepResultExtractor.Output output) + { + _jsonNodes.Add(Serialize("Models.SweepResultExtractor", input, output)); + } + [Obsolete] public Microsoft.ML.Legacy.Models.TrainTestBinaryEvaluator.Output Add(Microsoft.ML.Legacy.Models.TrainTestBinaryEvaluator input) { @@ -4093,6 +4121,120 @@ public PAVCalibratorPipelineStep(Output output) } } + namespace Legacy.Models + { + + /// + /// AutoML pipeline sweeping optimzation macro. + /// + [Obsolete] + public sealed partial class PipelineSweeper + { + + + /// + /// The data to be used for training. + /// + [Obsolete] + public Var TrainingData { get; set; } = new Var(); + + /// + /// The data to be used for testing. + /// + [Obsolete] + public Var TestingData { get; set; } = new Var(); + + /// + /// The arguments for creating an AutoMlState component. + /// + [JsonConverter(typeof(ComponentSerializer))] + [Obsolete] + public AutoMlStateBase StateArguments { get; set; } + + /// + /// The stateful object conducting of the autoML search. + /// + [Obsolete] + public Var State { get; set; } = new Var(); + + /// + /// Number of candidate pipelines to retrieve each round. + /// + [Obsolete] + public int BatchSize { get; set; } + + /// + /// Output datasets from previous iteration of sweep. 
+ /// + [Obsolete] + public ArrayVar CandidateOutputs { get; set; } = new ArrayVar(); + + /// + /// Column(s) to use as Role 'Label' + /// + [Obsolete] + public string[] LabelColumns { get; set; } + + /// + /// Column(s) to use as Role 'Group' + /// + [Obsolete] + public string[] GroupColumns { get; set; } + + /// + /// Column(s) to use as Role 'Weight' + /// + [Obsolete] + public string[] WeightColumns { get; set; } + + /// + /// Column(s) to use as Role 'Name' + /// + [Obsolete] + public string[] NameColumns { get; set; } + + /// + /// Column(s) to use as Role 'NumericFeature' + /// + [Obsolete] + public string[] NumericFeatureColumns { get; set; } + + /// + /// Column(s) to use as Role 'CategoricalFeature' + /// + [Obsolete] + public string[] CategoricalFeatureColumns { get; set; } + + /// + /// Column(s) to use as Role 'TextFeature' + /// + [Obsolete] + public string[] TextFeatureColumns { get; set; } + + /// + /// Column(s) to use as Role 'ImagePath' + /// + [Obsolete] + public string[] ImagePathColumns { get; set; } + + + [Obsolete] + public sealed class Output + { + /// + /// Stateful autoML object, keeps track of where the search in progress. + /// + public Var State { get; set; } = new Var(); + + /// + /// Results of the sweep, including pipelines (as graph strings), IDs, and metric values. + /// + public Var Results { get; set; } = new Var(); + + } + } + } + namespace Legacy.Models { @@ -4528,6 +4670,41 @@ public sealed class Output } } + namespace Legacy.Models + { + + /// + /// Extracts the sweep result. + /// + [Obsolete] + public sealed partial class SweepResultExtractor + { + + + /// + /// The stateful object conducting of the autoML search. + /// + [Obsolete] + public Var State { get; set; } = new Var(); + + + [Obsolete] + public sealed class Output + { + /// + /// Stateful autoML object, keeps track of where the search in progress. 
+ /// + public Var State { get; set; } = new Var(); + + /// + /// Results of the sweep, including pipelines (as graph strings), IDs, and metric values. + /// + public Var Results { get; set; } = new Var(); + + } + } + } + namespace Legacy.Models { @@ -20212,6 +20389,150 @@ public WordTokenizerPipelineStep(Output output) namespace Runtime { + [Obsolete] + public abstract class AutoMlEngine : ComponentKind {} + + + + /// + /// AutoML engine that returns learners with default settings. + /// + [Obsolete] + public sealed class DefaultsAutoMlEngine : AutoMlEngine + { + [Obsolete] + internal override string ComponentName => "Defaults"; + } + + + + /// + /// AutoML engine that consists of distinct, hierarchical stages of operation. + /// + [Obsolete] + public sealed class RocketAutoMlEngine : AutoMlEngine + { + /// + /// Number of learners to retain for second stage. + /// + [Obsolete] + public int TopKLearners { get; set; } = 2; + + /// + /// Number of trials for retained second stage learners. + /// + [Obsolete] + public int SecondRoundTrialsPerLearner { get; set; } = 5; + + /// + /// Use random initialization only. + /// + [Obsolete] + public bool RandomInitialization { get; set; } = false; + + /// + /// Number of initilization pipelines, used for random initialization only. + /// + [Obsolete] + public int NumInitializationPipelines { get; set; } = 20; + + [Obsolete] + internal override string ComponentName => "Rocket"; + } + + + + /// + /// AutoML engine using uniform random sampling. 
+ /// + [Obsolete] + public sealed class UniformRandomAutoMlEngine : AutoMlEngine + { + [Obsolete] + internal override string ComponentName => "UniformRandom"; + } + + [Obsolete] + public abstract class AutoMlStateBase : ComponentKind {} + + [Obsolete] + public enum PipelineSweeperSupportedMetricsMetrics + { + Auc = 0, + AccuracyMicro = 1, + AccuracyMacro = 2, + L1 = 3, + L2 = 4, + F1 = 5, + AuPrc = 6, + TopKAccuracy = 7, + Rms = 8, + LossFn = 9, + RSquared = 10, + LogLoss = 11, + LogLossReduction = 12, + Ndcg = 13, + Dcg = 14, + PositivePrecision = 15, + PositiveRecall = 16, + NegativePrecision = 17, + NegativeRecall = 18, + DrAtK = 19, + DrAtPFpr = 20, + DrAtNumPos = 21, + NumAnomalies = 22, + ThreshAtK = 23, + ThreshAtP = 24, + ThreshAtNumPos = 25, + Nmi = 26, + AvgMinScore = 27, + Dbi = 28 + } + + + + /// + /// State of an AutoML search and search space. + /// + [Obsolete] + public sealed class AutoMlStateAutoMlStateBase : AutoMlStateBase + { + /// + /// Supported metric for evaluator. + /// + [Obsolete] + public PipelineSweeperSupportedMetricsMetrics Metric { get; set; } = PipelineSweeperSupportedMetricsMetrics.Auc; + + /// + /// AutoML engine (pipeline optimizer) that generates next candidates. + /// + [JsonConverter(typeof(ComponentSerializer))] + [Obsolete] + public AutoMlEngine Engine { get; set; } + + /// + /// Kind of trainer for task, such as binary classification trainer, multiclass trainer, etc. + /// + [Obsolete] + public Microsoft.ML.Legacy.Models.MacroUtilsTrainerKinds TrainerKind { get; set; } = Microsoft.ML.Legacy.Models.MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer; + + /// + /// Arguments for creating terminator, which determines when to stop search. + /// + [JsonConverter(typeof(ComponentSerializer))] + [Obsolete] + public SearchTerminator TerminatorArgs { get; set; } + + /// + /// Learner set to sweep over (if available). 
+ /// + [Obsolete] + public string[] RequestedLearners { get; set; } + + [Obsolete] + internal override string ComponentName => "AutoMlState"; + } + [Obsolete] public abstract class BoosterParameterFunction : ComponentKind {} @@ -23313,6 +23634,27 @@ public sealed class SquaredLossSDCARegressionLossFunction : SDCARegressionLossFu internal override string ComponentName => "SquaredLoss"; } + [Obsolete] + public abstract class SearchTerminator : ComponentKind {} + + + + /// + /// Terminators a sweep based on total number of iterations. + /// + [Obsolete] + public sealed class IterationLimitedSearchTerminator : SearchTerminator + { + /// + /// Total number of iterations. + /// + [Obsolete] + public int FinalHistoryLength { get; set; } + + [Obsolete] + internal override string ComponentName => "IterationLimited"; + } + [Obsolete] public abstract class StopWordsRemover : ComponentKind {} diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs index 7740d2f6e6..ecb8ac79c3 100644 --- a/src/Microsoft.ML.PCA/PcaTrainer.cs +++ b/src/Microsoft.ML.PCA/PcaTrainer.cs @@ -498,7 +498,7 @@ public static PcaPredictor Create(IHostEnvironment env, ModelLoadContext ctx) return new PcaPredictor(env, ctx); } - public void SaveSummary(TextWriter writer, RoleMappedSchema schema) + void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) { ((ICanSaveInTextFormat)this).SaveAsText(writer, schema); } diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index 538fb51c65..368b0c0189 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -344,7 +344,7 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) SaveSummary(writer, schema); } - public void SaveAsCode(TextWriter writer, RoleMappedSchema schema) + void 
ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema) { Host.CheckValue(writer, nameof(writer)); Host.CheckValue(schema, nameof(schema)); @@ -353,9 +353,13 @@ public void SaveAsCode(TextWriter writer, RoleMappedSchema schema) LinearPredictorUtils.SaveAsCode(writer, in weights, Bias, schema); } - public abstract void SaveSummary(TextWriter writer, RoleMappedSchema schema); + [BestFriend] + private protected abstract void SaveSummary(TextWriter writer, RoleMappedSchema schema); + + void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) => SaveSummary(writer, schema); - public virtual Row GetSummaryIRowOrNull(RoleMappedSchema schema) + [BestFriend] + private protected virtual Row GetSummaryIRowOrNull(RoleMappedSchema schema) { var names = default(VBuffer>); MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, Weight.Length, ref names); @@ -368,9 +372,17 @@ public virtual Row GetSummaryIRowOrNull(RoleMappedSchema schema) return MetadataUtils.MetadataAsRow(builder.GetMetadata()); } - public virtual Row GetStatsIRowOrNull(RoleMappedSchema schema) => null; + Row ICanGetSummaryAsIRow.GetSummaryIRowOrNull(RoleMappedSchema schema) => GetSummaryIRowOrNull(schema); + + [BestFriend] + private protected virtual Row GetStatsIRowOrNull(RoleMappedSchema schema) => null; + + Row ICanGetSummaryAsIRow.GetStatsIRowOrNull(RoleMappedSchema schema) => GetStatsIRowOrNull(schema); + + [BestFriend] + private protected abstract void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator = null); - public abstract void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator = null); + void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator) => SaveAsIni(writer, schema, calibrator); public virtual void GetFeatureWeights(ref VBuffer weights) { @@ -487,7 +499,7 @@ public IParameterMixer CombineParameters(IList> mo return new 
LinearBinaryPredictor(Host, in weights, bias); } - public override void SaveSummary(TextWriter writer, RoleMappedSchema schema) + private protected override void SaveSummary(TextWriter writer, RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); @@ -500,7 +512,7 @@ public override void SaveSummary(TextWriter writer, RoleMappedSchema schema) } /// - public IList> GetSummaryInKeyValuePairs(RoleMappedSchema schema) + IList> ICanGetSummaryInKeyValuePairs.GetSummaryInKeyValuePairs(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); @@ -511,7 +523,7 @@ public IList> GetSummaryInKeyValuePairs(RoleMappedS return results; } - public override Row GetStatsIRowOrNull(RoleMappedSchema schema) + private protected override Row GetStatsIRowOrNull(RoleMappedSchema schema) { if (_stats == null) return null; @@ -521,7 +533,7 @@ public override Row GetStatsIRowOrNull(RoleMappedSchema schema) return MetadataUtils.MetadataAsRow(meta); } - public override void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator = null) + private protected override void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator = null) { Host.CheckValue(writer, nameof(writer)); Host.CheckValue(schema, nameof(schema)); @@ -553,7 +565,7 @@ public override PredictionKind PredictionKind /// /// Output the INI model to a given writer /// - public override void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator) + private protected override void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator) { if (calibrator != null) throw Host.ExceptNotImpl("Saving calibrators is not implemented yet."); @@ -617,7 +629,7 @@ private protected override void SaveCore(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); } - public override void SaveSummary(TextWriter writer, RoleMappedSchema schema) + private protected override void SaveSummary(TextWriter writer, RoleMappedSchema schema) { 
Host.CheckValue(writer, nameof(writer)); Host.CheckValue(schema, nameof(schema)); @@ -640,7 +652,7 @@ public IParameterMixer CombineParameters(IList> mo } /// - public IList> GetSummaryInKeyValuePairs(RoleMappedSchema schema) + IList> ICanGetSummaryInKeyValuePairs.GetSummaryInKeyValuePairs(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); @@ -698,7 +710,7 @@ protected override Float Score(in VBuffer src) return MathUtils.ExpSlow(base.Score(in src)); } - public override void SaveSummary(TextWriter writer, RoleMappedSchema schema) + private protected override void SaveSummary(TextWriter writer, RoleMappedSchema schema) { Host.CheckValue(writer, nameof(writer)); Host.CheckValue(schema, nameof(schema)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index a114c2bb18..64e0d0c099 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -781,7 +781,7 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) { writer.WriteLine(nameof(MulticlassLogisticRegression) + " bias and non-zero weights"); - foreach (var namedValues in GetSummaryInKeyValuePairs(schema)) + foreach (var namedValues in ((ICanGetSummaryInKeyValuePairs)this).GetSummaryInKeyValuePairs(schema)) { Host.Assert(namedValues.Value is float); writer.WriteLine("\t{0}\t{1}", namedValues.Key, (float)namedValues.Value); @@ -792,7 +792,7 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) } /// - public IList> GetSummaryInKeyValuePairs(RoleMappedSchema schema) + IList> ICanGetSummaryInKeyValuePairs.GetSummaryInKeyValuePairs(RoleMappedSchema schema) { Host.CheckValueOrNull(schema); @@ -832,7 +832,7 @@ public IList> 
GetSummaryInKeyValuePairs(RoleMappedS /// /// Output the text model to a given writer /// - public void SaveAsCode(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema) { Host.CheckValue(writer, nameof(writer)); Host.CheckValueOrNull(schema); @@ -851,7 +851,7 @@ public void SaveAsCode(TextWriter writer, RoleMappedSchema schema) writer.WriteLine("output[{0}] = Math.Exp(scores[{0}] - softmax);", c); } - public void SaveSummary(TextWriter writer, RoleMappedSchema schema) + void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) { ((ICanSaveInTextFormat)this).SaveAsText(writer, schema); } @@ -985,12 +985,12 @@ public IDataView GetSummaryDataView(RoleMappedSchema schema) return bldr.GetDataView(); } - public Row GetSummaryIRowOrNull(RoleMappedSchema schema) + Row ICanGetSummaryAsIRow.GetSummaryIRowOrNull(RoleMappedSchema schema) { return null; } - public Row GetStatsIRowOrNull(RoleMappedSchema schema) + Row ICanGetSummaryAsIRow.GetStatsIRowOrNull(RoleMappedSchema schema) { if (_stats == null) return null; diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs index 36a05b7cc2..398fce9015 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs @@ -365,7 +365,7 @@ ValueMapper IValueMapper.GetMapper() return (ValueMapper)(Delegate)_impl.GetMapper(); } - public void SaveAsCode(TextWriter writer, RoleMappedSchema schema) + void ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema) { Host.CheckValue(writer, nameof(writer)); Host.CheckValue(schema, nameof(schema)); diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs index b04a360168..4a4fdc76b2 100644 --- 
a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs @@ -8,6 +8,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Trainers.FastTree; using Microsoft.ML.Runtime.Internal.Calibration; +using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Transforms.Text; using System.Linq; using Xunit; @@ -74,7 +75,7 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTest() Assert.True(predictions.ElementAt(1).Sentiment); // Get feature importance based on feature gain during training - var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema); + var summary = ((ICanGetSummaryInKeyValuePairs)pred).GetSummaryInKeyValuePairs(trainRoles.Schema); Assert.Equal(1.0, (double)summary[0].Value, 1); } @@ -148,7 +149,7 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTestWithWordE Assert.True(predictions.ElementAt(1).Sentiment); // Get feature importance based on feature gain during training - var summary = ((FeatureWeightsCalibratedPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema); + var summary = ((ICanGetSummaryInKeyValuePairs)pred).GetSummaryInKeyValuePairs(trainRoles.Schema); Assert.Equal(1.0, (double)summary[0].Value, 1); } From 5bf7402dd42bda8bfd04e32036b215b4ce2a5eb6 Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Thu, 6 Dec 2018 19:32:22 -0800 Subject: [PATCH 2/6] Internalize and explicitly implement IFeatureContributionMapper, IQuantileValueMapper, IQuantileRegressionPredictor. 
Rename FastTreePredictionWrapper to TreeEnsembleModelParameters and all descendants to XyzModelParameters --- .../Static/FastTreeRegression.cs | 2 +- .../Static/LightGBMRegression.cs | 2 +- .../Dirty/PredictorInterfaces.cs | 9 +- .../Prediction/Calibrator.cs | 4 +- src/Microsoft.ML.FastTree/FastTree.cs | 8 +- .../FastTreeClassification.cs | 24 +- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 32 +- .../FastTreeRegression.cs | 32 +- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 32 +- .../RandomForestClassification.cs | 24 +- .../RandomForestRegression.cs | 38 +- .../TreeEnsemble/TreeEnsembleCombiner.cs | 10 +- .../TreeEnsembleFeaturizer.cs | 16 +- .../TreeTrainersStatic.cs | 4 +- .../AssemblyRegistration.cs | 2 +- src/Microsoft.ML.Legacy/CSharpApi.cs | 342 ------------------ .../LightGbmBinaryTrainer.cs | 18 +- .../LightGbmMulticlassTrainer.cs | 4 +- .../LightGbmRankingTrainer.cs | 32 +- .../LightGbmRegressionTrainer.cs | 32 +- src/Microsoft.ML.LightGBM/LightGbmStatic.cs | 4 +- .../Standard/LinearPredictor.cs | 2 +- .../Algorithms/SmacSweeper.cs | 16 +- .../UnitTests/TestEntryPoints.cs | 6 +- .../TestPredictors.cs | 2 +- .../Training.cs | 8 +- .../EnvironmentExtensions.cs | 2 +- 27 files changed, 184 insertions(+), 523 deletions(-) diff --git a/docs/samples/Microsoft.ML.Samples/Static/FastTreeRegression.cs b/docs/samples/Microsoft.ML.Samples/Static/FastTreeRegression.cs index 66ddc6772c..ba271b25ce 100644 --- a/docs/samples/Microsoft.ML.Samples/Static/FastTreeRegression.cs +++ b/docs/samples/Microsoft.ML.Samples/Static/FastTreeRegression.cs @@ -30,7 +30,7 @@ public static void FastTreeRegression() var data = reader.Read(dataFile); // The predictor that gets produced out of training - FastTreeRegressionPredictor pred = null; + FastTreeRegressionModelParameters pred = null; // Create the estimator var learningPipeline = reader.MakeNewEstimator() diff --git a/docs/samples/Microsoft.ML.Samples/Static/LightGBMRegression.cs 
b/docs/samples/Microsoft.ML.Samples/Static/LightGBMRegression.cs index ca257d864f..9ab90eaf95 100644 --- a/docs/samples/Microsoft.ML.Samples/Static/LightGBMRegression.cs +++ b/docs/samples/Microsoft.ML.Samples/Static/LightGBMRegression.cs @@ -30,7 +30,7 @@ public static void LightGbmRegression() var (trainData, testData) = mlContext.Regression.TrainTestSplit(data, testFraction: 0.1); // The predictor that gets produced out of training - LightGbmRegressionPredictor pred = null; + LightGbmRegressionModelParameters pred = null; // Create the estimator var learningPipeline = reader.MakeNewEstimator() diff --git a/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs b/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs index 9b49be96b0..f2d1ed6cfc 100644 --- a/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs +++ b/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs @@ -38,7 +38,8 @@ public interface IParameterMixer /// Predictor that can specialize for quantile regression. It will produce a , given /// an array of quantiles. /// - public interface IQuantileRegressionPredictor + [BestFriend] + internal interface IQuantileRegressionPredictor { ISchemaBindableMapper CreateMapper(Double[] quantiles); } @@ -59,7 +60,8 @@ public interface IDistribution } // REVIEW: How should this quantile stuff work? - public interface IQuantileValueMapper + [BestFriend] + internal interface IQuantileValueMapper { ValueMapper, VBuffer> GetMapper(Float[] quantiles); } @@ -183,7 +185,8 @@ public interface IPredictorWithFeatureWeights : IHaveFeatureWeights /// Interface for mapping input values to corresponding feature contributions. /// This interface is commonly implemented by predictors. /// - public interface IFeatureContributionMapper : IPredictor + [BestFriend] + internal interface IFeatureContributionMapper : IPredictor { /// /// Get a delegate for mapping Contributions to Features. 
diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 45ccbf966a..716f6aaf40 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -258,7 +258,7 @@ public ValueMapper GetMapper() return (ValueMapper)(Delegate)del; } - public ValueMapper> GetFeatureContributionMapper(int top, int bottom, bool normalize) + ValueMapper> IFeatureContributionMapper.GetFeatureContributionMapper(int top, int bottom, bool normalize) { // REVIEW: checking this a bit too late. Host.Check(_featureContribution != null, "Predictor does not implement IFeatureContributionMapper"); @@ -682,7 +682,7 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) return new Bound(Host, this, schema); } - public ValueMapper> GetFeatureContributionMapper(int top, int bottom, bool normalize) + ValueMapper> IFeatureContributionMapper.GetFeatureContributionMapper(int top, int bottom, bool normalize) { // REVIEW: checking this a bit too late. 
Host.Check(_featureContribution != null, "Predictor does not implement " + nameof(IFeatureContributionMapper)); diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 5e3ee93891..7a17421853 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -2794,7 +2794,7 @@ public Dataset GetCompatibleDataset(RoleMappedData data, PredictionKind kind, in } } - public abstract class FastTreePredictionWrapper : + public abstract class TreeEnsembleModelParameters : PredictorBase, IValueMapper, ICanSaveInTextFormat, @@ -2839,7 +2839,7 @@ public abstract class FastTreePredictionWrapper : bool ICanSavePfa.CanSavePfa => true; bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; - protected FastTreePredictionWrapper(IHostEnvironment env, string name, TreeEnsemble trainedEnsemble, int numFeatures, string innerArgs) + protected TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemble trainedEnsemble, int numFeatures, string innerArgs) : base(env, name) { Host.CheckValue(trainedEnsemble, nameof(trainedEnsemble)); @@ -2860,7 +2860,7 @@ protected FastTreePredictionWrapper(IHostEnvironment env, string name, TreeEnsem OutputType = NumberType.Float; } - protected FastTreePredictionWrapper(IHostEnvironment env, string name, ModelLoadContext ctx, VersionInfo ver) + protected TreeEnsembleModelParameters(IHostEnvironment env, string name, ModelLoadContext ctx, VersionInfo ver) : base(env, name, ctx) { // *** Binary format *** @@ -2933,7 +2933,7 @@ protected virtual void Map(in VBuffer src, ref Float dst) dst = (Float)TrainedEnsemble.GetOutput(in src); } - public ValueMapper> GetFeatureContributionMapper(int top, int bottom, bool normalize) + ValueMapper> IFeatureContributionMapper.GetFeatureContributionMapper(int top, int bottom, bool normalize) { Host.Check(typeof(TSrc) == typeof(VBuffer)); Host.Check(typeof(TDst) == typeof(VBuffer)); diff --git 
a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 3effa46c02..dd3017704f 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -36,17 +36,17 @@ "fastrank", "fastrankwrapper")] -[assembly: LoadableClass(typeof(IPredictorProducing), typeof(FastTreeBinaryPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(IPredictorProducing), typeof(FastTreeBinaryModelParameters), null, typeof(SignatureLoadModel), "FastTree Binary Executor", - FastTreeBinaryPredictor.LoaderSignature)] + FastTreeBinaryModelParameters.LoaderSignature)] namespace Microsoft.ML.Trainers.FastTree { - public sealed class FastTreeBinaryPredictor : - FastTreePredictionWrapper + public sealed class FastTreeBinaryModelParameters : + TreeEnsembleModelParameters { - public const string LoaderSignature = "FastTreeBinaryExec"; - public const string RegistrationName = "FastTreeBinaryPredictor"; + internal const string LoaderSignature = "FastTreeBinaryExec"; + internal const string RegistrationName = "FastTreeBinaryPredictor"; private static VersionInfo GetVersionInfo() { @@ -60,7 +60,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010005, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(FastTreeBinaryPredictor).Assembly.FullName); + loaderAssemblyName: typeof(FastTreeBinaryModelParameters).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010002; @@ -69,12 +69,12 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010005; - internal FastTreeBinaryPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + internal FastTreeBinaryModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, 
trainedEnsemble, featureCount, innerArgs) { } - private FastTreeBinaryPredictor(IHostEnvironment env, ModelLoadContext ctx) + private FastTreeBinaryModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) { } @@ -85,12 +85,12 @@ private protected override void SaveCore(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); } - public static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) + private static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - var predictor = new FastTreeBinaryPredictor(env, ctx); + var predictor = new FastTreeBinaryModelParameters(env, ctx); ICalibrator calibrator; ctx.LoadModelOrNull(env, out calibrator, @"Calibrator"); if (calibrator == null) @@ -177,7 +177,7 @@ private protected override IPredictorWithFeatureWeights TrainModelCore(Tr // output probabilities when transformed using a scaled logistic function, // so transform the scores using that. - var pred = new FastTreeBinaryPredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs); + var pred = new FastTreeBinaryModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs); // FastTree's binary classification boosting framework's natural probabilistic interpretation // is explained in "From RankNet to LambdaRank to LambdaMART: An Overview" by Chris Burges. 
// The correctness of this scaling depends upon the gradient calculation in diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 451f47482e..31497d5168 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -33,9 +33,9 @@ "frrank", "btrank")] -[assembly: LoadableClass(typeof(FastTreeRankingPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(FastTreerankingModelParameters), null, typeof(SignatureLoadModel), "FastTree Ranking Executor", - FastTreeRankingPredictor.LoaderSignature)] + FastTreerankingModelParameters.LoaderSignature)] [assembly: LoadableClass(typeof(void), typeof(FastTree), null, typeof(SignatureEntryPointModule), "FastTree")] @@ -43,7 +43,7 @@ namespace Microsoft.ML.Trainers.FastTree { /// public sealed partial class FastTreeRankingTrainer - : BoostingFastTreeTrainerBase, FastTreeRankingPredictor> + : BoostingFastTreeTrainerBase, FastTreerankingModelParameters> { internal const string LoadNameValue = "FastTreeRanking"; internal const string UserNameValue = "FastTree (Boosted Trees) Ranking"; @@ -112,7 +112,7 @@ protected override float GetMaxLabel() return GetLabelGains().Length - 1; } - private protected override FastTreeRankingPredictor TrainModelCore(TrainContext context) + private protected override FastTreerankingModelParameters TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var trainData = context.TrainingSet; @@ -126,7 +126,7 @@ private protected override FastTreeRankingPredictor TrainModelCore(TrainContext TrainCore(ch); FeatureCount = trainData.Schema.Feature.Type.ValueCount; } - return new FastTreeRankingPredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs); + return new FastTreerankingModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs); } private Double[] GetLabelGains() @@ -454,10 +454,10 @@ protected override string GetTestGraphHeader() return 
headerBuilder.ToString(); } - protected override RankingPredictionTransformer MakeTransformer(FastTreeRankingPredictor model, Schema trainSchema) - => new RankingPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + protected override RankingPredictionTransformer MakeTransformer(FastTreerankingModelParameters model, Schema trainSchema) + => new RankingPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - public RankingPredictionTransformer Train(IDataView trainData, IDataView validationData = null) + public RankingPredictionTransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData); protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) @@ -1104,10 +1104,10 @@ private static extern unsafe void GetDerivatives( } } - public sealed class FastTreeRankingPredictor : FastTreePredictionWrapper + public sealed class FastTreerankingModelParameters : TreeEnsembleModelParameters { - public const string LoaderSignature = "FastTreeRankerExec"; - public const string RegistrationName = "FastTreeRankingPredictor"; + internal const string LoaderSignature = "FastTreeRankerExec"; + internal const string RegistrationName = "FastTreeRankingPredictor"; private static VersionInfo GetVersionInfo() { @@ -1121,7 +1121,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010004, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(FastTreeRankingPredictor).Assembly.FullName); + loaderAssemblyName: typeof(FastTreerankingModelParameters).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010002; @@ -1130,12 +1130,12 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010005; - internal FastTreeRankingPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + internal 
FastTreerankingModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } - private FastTreeRankingPredictor(IHostEnvironment env, ModelLoadContext ctx) + private FastTreerankingModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) { } @@ -1146,9 +1146,9 @@ private protected override void SaveCore(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); } - public static FastTreeRankingPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static FastTreerankingModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { - return new FastTreeRankingPredictor(env, ctx); + return new FastTreerankingModelParameters(env, ctx); } public override PredictionKind PredictionKind => PredictionKind.Ranking; diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index a187accb26..811fbf951a 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -28,15 +28,15 @@ "frr", "btr")] -[assembly: LoadableClass(typeof(FastTreeRegressionPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(FastTreeRegressionModelParameters), null, typeof(SignatureLoadModel), "FastTree Regression Executor", - FastTreeRegressionPredictor.LoaderSignature)] + FastTreeRegressionModelParameters.LoaderSignature)] namespace Microsoft.ML.Trainers.FastTree { /// public sealed partial class FastTreeRegressionTrainer - : BoostingFastTreeTrainerBase, FastTreeRegressionPredictor> + : BoostingFastTreeTrainerBase, FastTreeRegressionModelParameters> { public const string LoadNameValue = "FastTreeRegression"; internal const string UserNameValue = "FastTree (Boosted Trees) Regression"; @@ -85,7 +85,7 @@ internal FastTreeRegressionTrainer(IHostEnvironment env, Arguments args) { } 
- private protected override FastTreeRegressionPredictor TrainModelCore(TrainContext context) + private protected override FastTreeRegressionModelParameters TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var trainData = context.TrainingSet; @@ -101,7 +101,7 @@ private protected override FastTreeRegressionPredictor TrainModelCore(TrainConte ConvertData(trainData); TrainCore(ch); } - return new FastTreeRegressionPredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs); + return new FastTreeRegressionModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs); } protected override void CheckArgs(IChannel ch) @@ -164,10 +164,10 @@ protected override Test ConstructTestForTrainingData() return new RegressionTest(ConstructScoreTracker(TrainSet)); } - protected override RegressionPredictionTransformer MakeTransformer(FastTreeRegressionPredictor model, Schema trainSchema) - => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + protected override RegressionPredictionTransformer MakeTransformer(FastTreeRegressionModelParameters model, Schema trainSchema) + => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - public RegressionPredictionTransformer Train(IDataView trainData, IDataView validationData = null) + public RegressionPredictionTransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData); protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) @@ -441,10 +441,10 @@ protected override void GetGradientInOneQuery(int query, int threadIndex) } } - public sealed class FastTreeRegressionPredictor : FastTreePredictionWrapper + public sealed class FastTreeRegressionModelParameters : TreeEnsembleModelParameters { - public const string LoaderSignature = "FastTreeRegressionExec"; - public const string RegistrationName = "FastTreeRegressionPredictor"; + internal const string 
LoaderSignature = "FastTreeRegressionExec"; + internal const string RegistrationName = "FastTreeRegressionPredictor"; private static VersionInfo GetVersionInfo() { @@ -458,7 +458,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010004, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(FastTreeRegressionPredictor).Assembly.FullName); + loaderAssemblyName: typeof(FastTreeRegressionModelParameters).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010002; @@ -467,12 +467,12 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010005; - internal FastTreeRegressionPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + internal FastTreeRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } - private FastTreeRegressionPredictor(IHostEnvironment env, ModelLoadContext ctx) + private FastTreeRegressionModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) { } @@ -483,12 +483,12 @@ private protected override void SaveCore(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); } - public static FastTreeRegressionPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static FastTreeRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new FastTreeRegressionPredictor(env, ctx); + return new FastTreeRegressionModelParameters(env, ctx); } public override PredictionKind PredictionKind => PredictionKind.Regression; diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 
e49886884c..26ab590559 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -23,9 +23,9 @@ FastTreeTweedieTrainer.LoadNameValue, FastTreeTweedieTrainer.ShortName)] -[assembly: LoadableClass(typeof(FastTreeTweediePredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(FastTreeTweedieModelParameters), null, typeof(SignatureLoadModel), "FastTree Tweedie Regression Executor", - FastTreeTweediePredictor.LoaderSignature)] + FastTreeTweedieModelParameters.LoaderSignature)] namespace Microsoft.ML.Trainers.FastTree { @@ -34,7 +34,7 @@ namespace Microsoft.ML.Trainers.FastTree // https://arxiv.org/pdf/1508.06378.pdf /// public sealed partial class FastTreeTweedieTrainer - : BoostingFastTreeTrainerBase, FastTreeTweediePredictor> + : BoostingFastTreeTrainerBase, FastTreeTweedieModelParameters> { internal const string LoadNameValue = "FastTreeTweedieRegression"; internal const string UserNameValue = "FastTree (Boosted Trees) Tweedie Regression"; @@ -87,7 +87,7 @@ internal FastTreeTweedieTrainer(IHostEnvironment env, Arguments args) Initialize(); } - private protected override FastTreeTweediePredictor TrainModelCore(TrainContext context) + private protected override FastTreeTweedieModelParameters TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var trainData = context.TrainingSet; @@ -104,7 +104,7 @@ private protected override FastTreeTweediePredictor TrainModelCore(TrainContext ConvertData(trainData); TrainCore(ch); } - return new FastTreeTweediePredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs); + return new FastTreeTweedieModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs); } protected override void CheckArgs(IChannel ch) @@ -316,10 +316,10 @@ protected override void Train(IChannel ch) PrintTestGraph(ch); } - protected override RegressionPredictionTransformer MakeTransformer(FastTreeTweediePredictor model, Schema trainSchema) - => new 
RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + protected override RegressionPredictionTransformer MakeTransformer(FastTreeTweedieModelParameters model, Schema trainSchema) + => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - public RegressionPredictionTransformer Train(IDataView trainData, IDataView validationData = null) + public RegressionPredictionTransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData); protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) @@ -446,10 +446,10 @@ protected override void GetGradientInOneQuery(int query, int threadIndex) } } - public sealed class FastTreeTweediePredictor : FastTreePredictionWrapper + public sealed class FastTreeTweedieModelParameters : TreeEnsembleModelParameters { - public const string LoaderSignature = "FastTreeTweedieExec"; - public const string RegistrationName = "FastTreeTweediePredictor"; + internal const string LoaderSignature = "FastTreeTweedieExec"; + internal const string RegistrationName = "FastTreeTweediePredictor"; private static VersionInfo GetVersionInfo() { @@ -461,7 +461,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(FastTreeTweediePredictor).Assembly.FullName); + loaderAssemblyName: typeof(FastTreeTweedieModelParameters).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010001; @@ -470,12 +470,12 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010003; - internal FastTreeTweediePredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + internal FastTreeTweedieModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, 
RegistrationName, trainedEnsemble, featureCount, innerArgs) { } - private FastTreeTweediePredictor(IHostEnvironment env, ModelLoadContext ctx) + private FastTreeTweedieModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) { } @@ -486,12 +486,12 @@ private protected override void SaveCore(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); } - public static FastTreeTweediePredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static FastTreeTweedieModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new FastTreeTweediePredictor(env, ctx); + return new FastTreeTweedieModelParameters(env, ctx); } protected override void Map(in VBuffer src, ref float dst) diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index eb4a773cf9..0f2cbe5fdc 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -25,9 +25,9 @@ FastForestClassification.ShortName, "ffc")] -[assembly: LoadableClass(typeof(IPredictorProducing), typeof(FastForestClassificationPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(IPredictorProducing), typeof(FastForestClassificationModelParameters), null, typeof(SignatureLoadModel), "FastForest Binary Executor", - FastForestClassificationPredictor.LoaderSignature)] + FastForestClassificationModelParameters.LoaderSignature)] [assembly: LoadableClass(typeof(void), typeof(FastForest), null, typeof(SignatureEntryPointModule), "FastForest")] @@ -46,11 +46,11 @@ public FastForestArgumentsBase() } } - public sealed class FastForestClassificationPredictor : - FastTreePredictionWrapper + public sealed class FastForestClassificationModelParameters : + TreeEnsembleModelParameters { - 
public const string LoaderSignature = "FastForestBinaryExec"; - public const string RegistrationName = "FastForestClassificationPredictor"; + internal const string LoaderSignature = "FastForestBinaryExec"; + internal const string RegistrationName = "FastForestClassificationPredictor"; private static VersionInfo GetVersionInfo() { @@ -65,7 +65,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010005, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(FastForestClassificationPredictor).Assembly.FullName); + loaderAssemblyName: typeof(FastForestClassificationModelParameters).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010003; @@ -79,11 +79,11 @@ private static VersionInfo GetVersionInfo() /// public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - public FastForestClassificationPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + public FastForestClassificationModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } - private FastForestClassificationPredictor(IHostEnvironment env, ModelLoadContext ctx) + private FastForestClassificationModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) { } @@ -94,12 +94,12 @@ private protected override void SaveCore(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); } - public static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) + private static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - var predictor = new FastForestClassificationPredictor(env, ctx); + var predictor = new 
FastForestClassificationModelParameters(env, ctx); ICalibrator calibrator; ctx.LoadModelOrNull(env, out calibrator, @"Calibrator"); if (calibrator == null) @@ -192,7 +192,7 @@ private protected override IPredictorWithFeatureWeights TrainModelCore(Tr // calibrator, transform the scores using that. // REVIEW: Need a way to signal the outside world that we prefer simple sigmoid? - return new FastForestClassificationPredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs); + return new FastForestClassificationModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs); } protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch) diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index 1b70ab32f4..8b32cea962 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -22,21 +22,21 @@ FastForestRegression.LoadNameValue, FastForestRegression.ShortName)] -[assembly: LoadableClass(typeof(FastForestRegressionPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(FastForestRegressionModelParameters), null, typeof(SignatureLoadModel), "FastForest Regression Executor", - FastForestRegressionPredictor.LoaderSignature)] + FastForestRegressionModelParameters.LoaderSignature)] namespace Microsoft.ML.Trainers.FastTree { - public sealed class FastForestRegressionPredictor : - FastTreePredictionWrapper, + public sealed class FastForestRegressionModelParameters : + TreeEnsembleModelParameters, IQuantileValueMapper, IQuantileRegressionPredictor { private readonly int _quantileSampleCount; - public const string LoaderSignature = "FastForestRegressionExec"; - public const string RegistrationName = "FastForestRegressionPredictor"; + internal const string LoaderSignature = "FastForestRegressionExec"; + internal const string RegistrationName = "FastForestRegressionPredictor"; private static VersionInfo GetVersionInfo() 
{ @@ -51,7 +51,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010005, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(FastForestRegressionPredictor).Assembly.FullName); + loaderAssemblyName: typeof(FastForestRegressionModelParameters).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010003; @@ -60,13 +60,13 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010006; - public FastForestRegressionPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs, int samplesCount) + public FastForestRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs, int samplesCount) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { _quantileSampleCount = samplesCount; } - private FastForestRegressionPredictor(IHostEnvironment env, ModelLoadContext ctx) + private FastForestRegressionModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) { // *** Binary format *** @@ -91,12 +91,12 @@ private protected override void SaveCore(ModelSaveContext ctx) ctx.Writer.Write(_quantileSampleCount); } - public static FastForestRegressionPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static FastForestRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new FastForestRegressionPredictor(env, ctx); + return new FastForestRegressionModelParameters(env, ctx); } public override PredictionKind PredictionKind => PredictionKind.Regression; @@ -111,7 +111,7 @@ protected override void Map(in VBuffer src, ref float dst) dst = (float)TrainedEnsemble.GetOutput(in src) / TrainedEnsemble.NumTrees; } - public ValueMapper, 
VBuffer> GetMapper(float[] quantiles) + ValueMapper, VBuffer> IQuantileValueMapper.GetMapper(float[] quantiles) { return (in VBuffer src, ref VBuffer dst) => @@ -128,7 +128,7 @@ public ValueMapper, VBuffer> GetMapper(float[] quantiles) }; } - public ISchemaBindableMapper CreateMapper(Double[] quantiles) + ISchemaBindableMapper IQuantileRegressionPredictor.CreateMapper(Double[] quantiles) { Host.CheckNonEmpty(quantiles, nameof(quantiles)); return new SchemaBindableQuantileRegressionPredictor(this, quantiles); @@ -137,7 +137,7 @@ public ISchemaBindableMapper CreateMapper(Double[] quantiles) /// public sealed partial class FastForestRegression - : RandomForestTrainerBase, FastForestRegressionPredictor> + : RandomForestTrainerBase, FastForestRegressionModelParameters> { public sealed class Arguments : FastForestArgumentsBase { @@ -188,7 +188,7 @@ public FastForestRegression(IHostEnvironment env, Arguments args) { } - private protected override FastForestRegressionPredictor TrainModelCore(TrainContext context) + private protected override FastForestRegressionModelParameters TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var trainData = context.TrainingSet; @@ -205,7 +205,7 @@ private protected override FastForestRegressionPredictor TrainModelCore(TrainCon ConvertData(trainData); TrainCore(ch); } - return new FastForestRegressionPredictor(Host, TrainedEnsemble, FeatureCount, InnerArgs, Args.QuantileSampleCount); + return new FastForestRegressionModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs, Args.QuantileSampleCount); } protected override void PrepareLabels(IChannel ch) @@ -222,10 +222,10 @@ protected override Test ConstructTestForTrainingData() return new RegressionTest(ConstructScoreTracker(TrainSet)); } - protected override RegressionPredictionTransformer MakeTransformer(FastForestRegressionPredictor model, Schema trainSchema) - => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + 
protected override RegressionPredictionTransformer MakeTransformer(FastForestRegressionModelParameters model, Schema trainSchema) + => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - public RegressionPredictionTransformer Train(IDataView trainData, IDataView validationData = null) + public RegressionPredictionTransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData); protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs index 84b04ed793..754c18f2f6 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs @@ -57,7 +57,7 @@ public IPredictor CombineModels(IEnumerable models) predictor = calibrated.SubPredictor; paramA = -(calibrated.Calibrator as PlattCalibrator).ParamA; } - var tree = predictor as FastTreePredictionWrapper; + var tree = predictor as TreeEnsembleModelParameters; if (tree == null) throw _host.Except("Model is not a tree ensemble"); foreach (var t in tree.TrainedEnsemble.Trees) @@ -99,14 +99,14 @@ public IPredictor CombineModels(IEnumerable models) { case PredictionKind.BinaryClassification: if (!binaryClassifier) - return new FastTreeBinaryPredictor(_host, ensemble, featureCount, null); + return new FastTreeBinaryModelParameters(_host, ensemble, featureCount, null); var cali = new PlattCalibrator(_host, -1, 0); - return new FeatureWeightsCalibratedPredictor(_host, new FastTreeBinaryPredictor(_host, ensemble, featureCount, null), cali); + return new FeatureWeightsCalibratedPredictor(_host, new FastTreeBinaryModelParameters(_host, ensemble, featureCount, null), cali); case PredictionKind.Regression: - return new FastTreeRegressionPredictor(_host, ensemble, featureCount, null); + return new 
FastTreeRegressionModelParameters(_host, ensemble, featureCount, null); case PredictionKind.Ranking: - return new FastTreeRankingPredictor(_host, ensemble, featureCount, null); + return new FastTreerankingModelParameters(_host, ensemble, featureCount, null); default: _host.Assert(false); throw _host.ExceptNotSupp(); diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index 1e9cda953f..065faea1b3 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -260,7 +260,7 @@ private sealed class State { private readonly IExceptionContext _ectx; private readonly Row _input; - private readonly FastTreePredictionWrapper _ensemble; + private readonly TreeEnsembleModelParameters _ensemble; private readonly int _numTrees; private readonly int _numLeaves; @@ -276,7 +276,7 @@ private sealed class State private long _cachedLeafBuilderPosition; private long _cachedPathBuilderPosition; - public State(IExceptionContext ectx, Row input, FastTreePredictionWrapper ensemble, int numLeaves, int featureIndex) + public State(IExceptionContext ectx, Row input, TreeEnsembleModelParameters ensemble, int numLeaves, int featureIndex) { Contracts.AssertValue(ectx); _ectx = ectx; @@ -422,7 +422,7 @@ private static VersionInfo GetVersionInfo() } private readonly IHost _host; - private readonly FastTreePredictionWrapper _ensemble; + private readonly TreeEnsembleModelParameters _ensemble; private readonly int _totalLeafCount; public TreeEnsembleFeaturizerBindableMapper(IHostEnvironment env, Arguments args, IPredictor predictor) @@ -434,7 +434,7 @@ public TreeEnsembleFeaturizerBindableMapper(IHostEnvironment env, Arguments args if (predictor is CalibratedPredictorBase) predictor = ((CalibratedPredictorBase)predictor).SubPredictor; - _ensemble = predictor as FastTreePredictionWrapper; + _ensemble = predictor as TreeEnsembleModelParameters; _host.Check(_ensemble != 
null, "Predictor in model file does not have compatible type"); _totalLeafCount = CountLeaves(_ensemble); @@ -449,7 +449,7 @@ public TreeEnsembleFeaturizerBindableMapper(IHostEnvironment env, ModelLoadConte // *** Binary format *** // ensemble - ctx.LoadModel(env, out _ensemble, "Ensemble"); + ctx.LoadModel(env, out _ensemble, "Ensemble"); _totalLeafCount = CountLeaves(_ensemble); } @@ -466,7 +466,7 @@ public void Save(ModelSaveContext ctx) ctx.SaveModel(_ensemble, "Ensemble"); } - private static int CountLeaves(FastTreePredictionWrapper ensemble) + private static int CountLeaves(TreeEnsembleModelParameters ensemble) { Contracts.AssertValue(ensemble); @@ -644,7 +644,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV // Make sure that the given predictor has the correct number of input features. if (predictor is CalibratedPredictorBase) predictor = ((CalibratedPredictorBase)predictor).SubPredictor; - // Predictor should be a FastTreePredictionWrapper, which implements IValueMapper, so this should + // Predictor should be a TreeEnsembleModelParameters, which implements IValueMapper, so this should // be non-null. var vm = predictor as IValueMapper; ch.CheckUserArg(vm != null, nameof(args.TrainedModelFile), "Predictor in model file does not have compatible type"); @@ -708,7 +708,7 @@ public static IDataTransform CreateForEntryPoint(IHostEnvironment env, Arguments // Make sure that the given predictor has the correct number of input features. if (predictor is CalibratedPredictorBase) predictor = ((CalibratedPredictorBase)predictor).SubPredictor; - // Predictor should be a FastTreePredictionWrapper, which implements IValueMapper, so this should + // Predictor should be a TreeEnsembleModelParameters, which implements IValueMapper, so this should // be non-null. 
var vm = predictor as IValueMapper; ch.CheckUserArg(vm != null, nameof(args.PredictorModel), "Predictor does not have compatible type"); diff --git a/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs b/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs index 176f0cead8..debe52f7bd 100644 --- a/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs @@ -48,7 +48,7 @@ public static Scalar FastTree(this RegressionContext.RegressionTrainers c int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, double learningRate = Defaults.LearningRates, Action advancedSettings = null, - Action onFit = null) + Action onFit = null) { CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings, onFit); @@ -144,7 +144,7 @@ public static Scalar FastTree(this RankingContext.RankingTrainers c int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, double learningRate = Defaults.LearningRates, Action advancedSettings = null, - Action onFit = null) + Action onFit = null) { CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings, onFit); diff --git a/src/Microsoft.ML.Legacy/AssemblyRegistration.cs b/src/Microsoft.ML.Legacy/AssemblyRegistration.cs index 3766199c56..5f633a0f67 100644 --- a/src/Microsoft.ML.Legacy/AssemblyRegistration.cs +++ b/src/Microsoft.ML.Legacy/AssemblyRegistration.cs @@ -43,7 +43,7 @@ private static bool LoadStandardAssemblies() _ = typeof(TextLoader).Assembly; // ML.Data //_ = typeof(EnsemblePredictor).Assembly); // ML.Ensemble BUG https://github.com/dotnet/machinelearning/issues/1078 Ensemble isn't in a NuGet package - _ = typeof(FastTreeBinaryPredictor).Assembly; // ML.FastTree + _ = typeof(FastTreeBinaryModelParameters).Assembly; // ML.FastTree _ = typeof(KMeansModelParameters).Assembly; // ML.KMeansClustering _ = typeof(Maml).Assembly; // ML.Maml _ = typeof(PcaPredictor).Assembly; // ML.PCA diff --git 
a/src/Microsoft.ML.Legacy/CSharpApi.cs b/src/Microsoft.ML.Legacy/CSharpApi.cs index bcd35d4988..c7e5c3ff89 100644 --- a/src/Microsoft.ML.Legacy/CSharpApi.cs +++ b/src/Microsoft.ML.Legacy/CSharpApi.cs @@ -400,20 +400,6 @@ public void Add(Microsoft.ML.Legacy.Models.PAVCalibrator input, Microsoft.ML.Leg _jsonNodes.Add(Serialize("Models.PAVCalibrator", input, output)); } - [Obsolete] - public Microsoft.ML.Legacy.Models.PipelineSweeper.Output Add(Microsoft.ML.Legacy.Models.PipelineSweeper input) - { - var output = new Microsoft.ML.Legacy.Models.PipelineSweeper.Output(); - Add(input, output); - return output; - } - - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.PipelineSweeper input, Microsoft.ML.Legacy.Models.PipelineSweeper.Output output) - { - _jsonNodes.Add(Serialize("Models.PipelineSweeper", input, output)); - } - [Obsolete] public Microsoft.ML.Legacy.Models.PlattCalibrator.Output Add(Microsoft.ML.Legacy.Models.PlattCalibrator input) { @@ -512,20 +498,6 @@ public void Add(Microsoft.ML.Legacy.Models.Summarizer input, Microsoft.ML.Legacy _jsonNodes.Add(Serialize("Models.Summarizer", input, output)); } - [Obsolete] - public Microsoft.ML.Legacy.Models.SweepResultExtractor.Output Add(Microsoft.ML.Legacy.Models.SweepResultExtractor input) - { - var output = new Microsoft.ML.Legacy.Models.SweepResultExtractor.Output(); - Add(input, output); - return output; - } - - [Obsolete] - public void Add(Microsoft.ML.Legacy.Models.SweepResultExtractor input, Microsoft.ML.Legacy.Models.SweepResultExtractor.Output output) - { - _jsonNodes.Add(Serialize("Models.SweepResultExtractor", input, output)); - } - [Obsolete] public Microsoft.ML.Legacy.Models.TrainTestBinaryEvaluator.Output Add(Microsoft.ML.Legacy.Models.TrainTestBinaryEvaluator input) { @@ -4121,120 +4093,6 @@ public PAVCalibratorPipelineStep(Output output) } } - namespace Legacy.Models - { - - /// - /// AutoML pipeline sweeping optimzation macro. 
- /// - [Obsolete] - public sealed partial class PipelineSweeper - { - - - /// - /// The data to be used for training. - /// - [Obsolete] - public Var TrainingData { get; set; } = new Var(); - - /// - /// The data to be used for testing. - /// - [Obsolete] - public Var TestingData { get; set; } = new Var(); - - /// - /// The arguments for creating an AutoMlState component. - /// - [JsonConverter(typeof(ComponentSerializer))] - [Obsolete] - public AutoMlStateBase StateArguments { get; set; } - - /// - /// The stateful object conducting of the autoML search. - /// - [Obsolete] - public Var State { get; set; } = new Var(); - - /// - /// Number of candidate pipelines to retrieve each round. - /// - [Obsolete] - public int BatchSize { get; set; } - - /// - /// Output datasets from previous iteration of sweep. - /// - [Obsolete] - public ArrayVar CandidateOutputs { get; set; } = new ArrayVar(); - - /// - /// Column(s) to use as Role 'Label' - /// - [Obsolete] - public string[] LabelColumns { get; set; } - - /// - /// Column(s) to use as Role 'Group' - /// - [Obsolete] - public string[] GroupColumns { get; set; } - - /// - /// Column(s) to use as Role 'Weight' - /// - [Obsolete] - public string[] WeightColumns { get; set; } - - /// - /// Column(s) to use as Role 'Name' - /// - [Obsolete] - public string[] NameColumns { get; set; } - - /// - /// Column(s) to use as Role 'NumericFeature' - /// - [Obsolete] - public string[] NumericFeatureColumns { get; set; } - - /// - /// Column(s) to use as Role 'CategoricalFeature' - /// - [Obsolete] - public string[] CategoricalFeatureColumns { get; set; } - - /// - /// Column(s) to use as Role 'TextFeature' - /// - [Obsolete] - public string[] TextFeatureColumns { get; set; } - - /// - /// Column(s) to use as Role 'ImagePath' - /// - [Obsolete] - public string[] ImagePathColumns { get; set; } - - - [Obsolete] - public sealed class Output - { - /// - /// Stateful autoML object, keeps track of where the search in progress. 
- /// - public Var State { get; set; } = new Var(); - - /// - /// Results of the sweep, including pipelines (as graph strings), IDs, and metric values. - /// - public Var Results { get; set; } = new Var(); - - } - } - } - namespace Legacy.Models { @@ -4670,41 +4528,6 @@ public sealed class Output } } - namespace Legacy.Models - { - - /// - /// Extracts the sweep result. - /// - [Obsolete] - public sealed partial class SweepResultExtractor - { - - - /// - /// The stateful object conducting of the autoML search. - /// - [Obsolete] - public Var State { get; set; } = new Var(); - - - [Obsolete] - public sealed class Output - { - /// - /// Stateful autoML object, keeps track of where the search in progress. - /// - public Var State { get; set; } = new Var(); - - /// - /// Results of the sweep, including pipelines (as graph strings), IDs, and metric values. - /// - public Var Results { get; set; } = new Var(); - - } - } - } - namespace Legacy.Models { @@ -20389,150 +20212,6 @@ public WordTokenizerPipelineStep(Output output) namespace Runtime { - [Obsolete] - public abstract class AutoMlEngine : ComponentKind {} - - - - /// - /// AutoML engine that returns learners with default settings. - /// - [Obsolete] - public sealed class DefaultsAutoMlEngine : AutoMlEngine - { - [Obsolete] - internal override string ComponentName => "Defaults"; - } - - - - /// - /// AutoML engine that consists of distinct, hierarchical stages of operation. - /// - [Obsolete] - public sealed class RocketAutoMlEngine : AutoMlEngine - { - /// - /// Number of learners to retain for second stage. - /// - [Obsolete] - public int TopKLearners { get; set; } = 2; - - /// - /// Number of trials for retained second stage learners. - /// - [Obsolete] - public int SecondRoundTrialsPerLearner { get; set; } = 5; - - /// - /// Use random initialization only. 
- /// - [Obsolete] - public bool RandomInitialization { get; set; } = false; - - /// - /// Number of initilization pipelines, used for random initialization only. - /// - [Obsolete] - public int NumInitializationPipelines { get; set; } = 20; - - [Obsolete] - internal override string ComponentName => "Rocket"; - } - - - - /// - /// AutoML engine using uniform random sampling. - /// - [Obsolete] - public sealed class UniformRandomAutoMlEngine : AutoMlEngine - { - [Obsolete] - internal override string ComponentName => "UniformRandom"; - } - - [Obsolete] - public abstract class AutoMlStateBase : ComponentKind {} - - [Obsolete] - public enum PipelineSweeperSupportedMetricsMetrics - { - Auc = 0, - AccuracyMicro = 1, - AccuracyMacro = 2, - L1 = 3, - L2 = 4, - F1 = 5, - AuPrc = 6, - TopKAccuracy = 7, - Rms = 8, - LossFn = 9, - RSquared = 10, - LogLoss = 11, - LogLossReduction = 12, - Ndcg = 13, - Dcg = 14, - PositivePrecision = 15, - PositiveRecall = 16, - NegativePrecision = 17, - NegativeRecall = 18, - DrAtK = 19, - DrAtPFpr = 20, - DrAtNumPos = 21, - NumAnomalies = 22, - ThreshAtK = 23, - ThreshAtP = 24, - ThreshAtNumPos = 25, - Nmi = 26, - AvgMinScore = 27, - Dbi = 28 - } - - - - /// - /// State of an AutoML search and search space. - /// - [Obsolete] - public sealed class AutoMlStateAutoMlStateBase : AutoMlStateBase - { - /// - /// Supported metric for evaluator. - /// - [Obsolete] - public PipelineSweeperSupportedMetricsMetrics Metric { get; set; } = PipelineSweeperSupportedMetricsMetrics.Auc; - - /// - /// AutoML engine (pipeline optimizer) that generates next candidates. - /// - [JsonConverter(typeof(ComponentSerializer))] - [Obsolete] - public AutoMlEngine Engine { get; set; } - - /// - /// Kind of trainer for task, such as binary classification trainer, multiclass trainer, etc. 
- /// - [Obsolete] - public Microsoft.ML.Legacy.Models.MacroUtilsTrainerKinds TrainerKind { get; set; } = Microsoft.ML.Legacy.Models.MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer; - - /// - /// Arguments for creating terminator, which determines when to stop search. - /// - [JsonConverter(typeof(ComponentSerializer))] - [Obsolete] - public SearchTerminator TerminatorArgs { get; set; } - - /// - /// Learner set to sweep over (if available). - /// - [Obsolete] - public string[] RequestedLearners { get; set; } - - [Obsolete] - internal override string ComponentName => "AutoMlState"; - } - [Obsolete] public abstract class BoosterParameterFunction : ComponentKind {} @@ -23634,27 +23313,6 @@ public sealed class SquaredLossSDCARegressionLossFunction : SDCARegressionLossFu internal override string ComponentName => "SquaredLoss"; } - [Obsolete] - public abstract class SearchTerminator : ComponentKind {} - - - - /// - /// Terminators a sweep based on total number of iterations. - /// - [Obsolete] - public sealed class IterationLimitedSearchTerminator : SearchTerminator - { - /// - /// Total number of iterations. 
- /// - [Obsolete] - public int FinalHistoryLength { get; set; } - - [Obsolete] - internal override string ComponentName => "IterationLimited"; - } - [Obsolete] public abstract class StopWordsRemover : ComponentKind {} diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index 4217ee0612..cdf06e19f1 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -20,16 +20,16 @@ new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer) }, LightGbmBinaryTrainer.UserName, LightGbmBinaryTrainer.LoadNameValue, LightGbmBinaryTrainer.ShortName, DocName = "trainer/LightGBM.md")] -[assembly: LoadableClass(typeof(IPredictorProducing), typeof(LightGbmBinaryPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(IPredictorProducing), typeof(LightGbmBinaryModelParameters), null, typeof(SignatureLoadModel), "LightGBM Binary Executor", - LightGbmBinaryPredictor.LoaderSignature)] + LightGbmBinaryModelParameters.LoaderSignature)] [assembly: LoadableClass(typeof(void), typeof(LightGbm), null, typeof(SignatureEntryPointModule), "LightGBM")] namespace Microsoft.ML.Runtime.LightGBM { /// - public sealed class LightGbmBinaryPredictor : FastTreePredictionWrapper + public sealed class LightGbmBinaryModelParameters : TreeEnsembleModelParameters { internal const string LoaderSignature = "LightGBMBinaryExec"; internal const string RegistrationName = "LightGBMBinaryPredictor"; @@ -47,7 +47,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010004, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(LightGbmBinaryPredictor).Assembly.FullName); + loaderAssemblyName: typeof(LightGbmBinaryModelParameters).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010002; @@ -55,12 +55,12 @@ private static VersionInfo 
GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010005; public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - internal LightGbmBinaryPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + internal LightGbmBinaryModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } - private LightGbmBinaryPredictor(IHostEnvironment env, ModelLoadContext ctx) + private LightGbmBinaryModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) { } @@ -71,12 +71,12 @@ private protected override void SaveCore(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); } - public static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) + private static IPredictorProducing Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - var predictor = new LightGbmBinaryPredictor(env, ctx); + var predictor = new LightGbmBinaryModelParameters(env, ctx); ICalibrator calibrator; ctx.LoadModelOrNull(env, out calibrator, @"Calibrator"); if (calibrator == null) @@ -132,7 +132,7 @@ private protected override IPredictorWithFeatureWeights CreatePredictor() { Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete"); var innerArgs = LightGbmInterfaceUtils.JoinParameters(Options); - var pred = new LightGbmBinaryPredictor(Host, TrainedEnsemble, FeatureCount, innerArgs); + var pred = new LightGbmBinaryModelParameters(Host, TrainedEnsemble, FeatureCount, innerArgs); var cali = new PlattCalibrator(Host, -0.5, 0); return new FeatureWeightsCalibratedPredictor(Host, pred, cali); } diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs 
b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index 62b476b40f..f829dfd095 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -82,9 +82,9 @@ private TreeEnsemble GetBinaryEnsemble(int classID) return res; } - private LightGbmBinaryPredictor CreateBinaryPredictor(int classID, string innerArgs) + private LightGbmBinaryModelParameters CreateBinaryPredictor(int classID, string innerArgs) { - return new LightGbmBinaryPredictor(Host, GetBinaryEnsemble(classID), FeatureCount, innerArgs); + return new LightGbmBinaryModelParameters(Host, GetBinaryEnsemble(classID), FeatureCount, innerArgs); } private protected override OvaPredictor CreatePredictor() diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index 4f0d49b75e..9d2e6a45d8 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -18,17 +18,17 @@ new[] { typeof(SignatureRankerTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer) }, "LightGBM Ranking", LightGbmRankingTrainer.LoadNameValue, LightGbmRankingTrainer.ShortName, DocName = "trainer/LightGBM.md")] -[assembly: LoadableClass(typeof(LightGbmRankingPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(LightGbmRankingModelParameters), null, typeof(SignatureLoadModel), "LightGBM Ranking Executor", - LightGbmRankingPredictor.LoaderSignature)] + LightGbmRankingModelParameters.LoaderSignature)] namespace Microsoft.ML.Runtime.LightGBM { - public sealed class LightGbmRankingPredictor : FastTreePredictionWrapper + public sealed class LightGbmRankingModelParameters : TreeEnsembleModelParameters { - public const string LoaderSignature = "LightGBMRankerExec"; - public const string RegistrationName = "LightGBMRankingPredictor"; + internal const string LoaderSignature = "LightGBMRankerExec"; + internal const 
string RegistrationName = "LightGBMRankingPredictor"; private static VersionInfo GetVersionInfo() { @@ -43,7 +43,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010004, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(LightGbmRankingPredictor).Assembly.FullName); + loaderAssemblyName: typeof(LightGbmRankingModelParameters).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010002; @@ -51,12 +51,12 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010005; public override PredictionKind PredictionKind => PredictionKind.Ranking; - internal LightGbmRankingPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + internal LightGbmRankingModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } - private LightGbmRankingPredictor(IHostEnvironment env, ModelLoadContext ctx) + private LightGbmRankingModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) { } @@ -67,14 +67,14 @@ private protected override void SaveCore(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); } - private static LightGbmRankingPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static LightGbmRankingModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { - return new LightGbmRankingPredictor(env, ctx); + return new LightGbmRankingModelParameters(env, ctx); } } /// - public sealed class LightGbmRankingTrainer : LightGbmTrainerBase, LightGbmRankingPredictor> + public sealed class LightGbmRankingTrainer : LightGbmTrainerBase, LightGbmRankingModelParameters> { public const string UserName = "LightGBM Ranking"; public const string LoadNameValue = "LightGBMRanking"; @@ -151,11 +151,11 @@ protected 
override void CheckLabelCompatible(SchemaShape.Column labelCol) error(); } - private protected override LightGbmRankingPredictor CreatePredictor() + private protected override LightGbmRankingModelParameters CreatePredictor() { Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete"); var innerArgs = LightGbmInterfaceUtils.JoinParameters(Options); - return new LightGbmRankingPredictor(Host, TrainedEnsemble, FeatureCount, innerArgs); + return new LightGbmRankingModelParameters(Host, TrainedEnsemble, FeatureCount, innerArgs); } protected override void CheckAndUpdateParametersBeforeTraining(IChannel ch, RoleMappedData data, float[] labels, int[] groups) @@ -178,10 +178,10 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc }; } - protected override RankingPredictionTransformer MakeTransformer(LightGbmRankingPredictor model, Schema trainSchema) - => new RankingPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + protected override RankingPredictionTransformer MakeTransformer(LightGbmRankingModelParameters model, Schema trainSchema) + => new RankingPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - public RankingPredictionTransformer Train(IDataView trainData, IDataView validationData = null) + public RankingPredictionTransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData); } diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs index feeb227ba4..5d2e2f1aa7 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs @@ -18,17 +18,17 @@ new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer) }, LightGbmRegressorTrainer.UserNameValue, LightGbmRegressorTrainer.LoadNameValue, LightGbmRegressorTrainer.ShortName, 
DocName = "trainer/LightGBM.md")] -[assembly: LoadableClass(typeof(LightGbmRegressionPredictor), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(LightGbmRegressionModelParameters), null, typeof(SignatureLoadModel), "LightGBM Regression Executor", - LightGbmRegressionPredictor.LoaderSignature)] + LightGbmRegressionModelParameters.LoaderSignature)] namespace Microsoft.ML.Runtime.LightGBM { /// - public sealed class LightGbmRegressionPredictor : FastTreePredictionWrapper + public sealed class LightGbmRegressionModelParameters : TreeEnsembleModelParameters { - public const string LoaderSignature = "LightGBMRegressionExec"; - public const string RegistrationName = "LightGBMRegressionPredictor"; + internal const string LoaderSignature = "LightGBMRegressionExec"; + internal const string RegistrationName = "LightGBMRegressionPredictor"; private static VersionInfo GetVersionInfo() { @@ -43,7 +43,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010004, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(LightGbmRegressionPredictor).Assembly.FullName); + loaderAssemblyName: typeof(LightGbmRegressionModelParameters).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010002; @@ -51,12 +51,12 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010005; public override PredictionKind PredictionKind => PredictionKind.Regression; - internal LightGbmRegressionPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + internal LightGbmRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } - private LightGbmRegressionPredictor(IHostEnvironment env, ModelLoadContext ctx) + private LightGbmRegressionModelParameters(IHostEnvironment env, ModelLoadContext 
ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) { } @@ -67,17 +67,17 @@ private protected override void SaveCore(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); } - public static LightGbmRegressionPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static LightGbmRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new LightGbmRegressionPredictor(env, ctx); + return new LightGbmRegressionModelParameters(env, ctx); } } /// - public sealed class LightGbmRegressorTrainer : LightGbmTrainerBase, LightGbmRegressionPredictor> + public sealed class LightGbmRegressorTrainer : LightGbmTrainerBase, LightGbmRegressionModelParameters> { internal const string Summary = "LightGBM Regression"; internal const string LoadNameValue = "LightGBMRegression"; @@ -119,12 +119,12 @@ internal LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args) { } - private protected override LightGbmRegressionPredictor CreatePredictor() + private protected override LightGbmRegressionModelParameters CreatePredictor() { Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete"); var innerArgs = LightGbmInterfaceUtils.JoinParameters(Options); - return new LightGbmRegressionPredictor(Host, TrainedEnsemble, FeatureCount, innerArgs); + return new LightGbmRegressionModelParameters(Host, TrainedEnsemble, FeatureCount, innerArgs); } protected override void CheckDataValid(IChannel ch, RoleMappedData data) @@ -155,10 +155,10 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc }; } - protected override RegressionPredictionTransformer MakeTransformer(LightGbmRegressionPredictor model, Schema trainSchema) - => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + protected override RegressionPredictionTransformer 
MakeTransformer(LightGbmRegressionModelParameters model, Schema trainSchema) + => new RegressionPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - public RegressionPredictionTransformer Train(IDataView trainData, IDataView validationData = null) + public RegressionPredictionTransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData); } diff --git a/src/Microsoft.ML.LightGBM/LightGbmStatic.cs b/src/Microsoft.ML.LightGBM/LightGbmStatic.cs index 659e6da141..3cdad93257 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmStatic.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmStatic.cs @@ -49,7 +49,7 @@ public static Scalar LightGbm(this RegressionContext.RegressionTrainers c double? learningRate = null, int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, Action advancedSettings = null, - Action onFit = null) + Action onFit = null) { CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); @@ -144,7 +144,7 @@ public static Scalar LightGbm(this RankingContext.RankingTrainers c double? 
learningRate = null, int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, Action advancedSettings = null, - Action onFit = null) + Action onFit = null) { CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); Contracts.CheckValue(groupId, nameof(groupId)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index 368b0c0189..45978c9d52 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -389,7 +389,7 @@ public virtual void GetFeatureWeights(ref VBuffer weights) Weight.CopyTo(ref weights); } - public ValueMapper> GetFeatureContributionMapper(int top, int bottom, bool normalize) + ValueMapper> IFeatureContributionMapper.GetFeatureContributionMapper(int top, int bottom, bool normalize) { Contracts.Check(typeof(TSrc) == typeof(VBuffer)); Contracts.Check(typeof(TDstContributions) == typeof(VBuffer)); diff --git a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs index cf4de98e66..9c7bf0199b 100644 --- a/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs +++ b/src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs @@ -106,13 +106,13 @@ public ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable previ } // Fit Random Forest Model on previous run data. - FastForestRegressionPredictor forestPredictor = FitModel(viableRuns); + FastForestRegressionModelParameters forestPredictor = FitModel(viableRuns); // Using acquisition function and current best, get candidate configuration(s). 
return GenerateCandidateConfigurations(numOfCandidates, viableRuns, forestPredictor); } - private FastForestRegressionPredictor FitModel(IEnumerable previousRuns) + private FastForestRegressionModelParameters FitModel(IEnumerable previousRuns) { Single[] targets = new Single[previousRuns.Count()]; Single[][] features = new Single[previousRuns.Count()][]; @@ -160,7 +160,7 @@ private FastForestRegressionPredictor FitModel(IEnumerable previousR /// History of previously evaluated points, with their emprical performance values. /// Trained random forest ensemble. Used in evaluating the candidates. /// An array of ParamaterSets which are the candidate configurations to sweep. - private ParameterSet[] GenerateCandidateConfigurations(int numOfCandidates, IEnumerable previousRuns, FastForestRegressionPredictor forest) + private ParameterSet[] GenerateCandidateConfigurations(int numOfCandidates, IEnumerable previousRuns, FastForestRegressionModelParameters forest) { // Get k best previous runs ParameterSets. ParameterSet[] bestKParamSets = GetKBestConfigurations(previousRuns, forest, _args.LocalSearchParentCount); @@ -188,7 +188,7 @@ private ParameterSet[] GenerateCandidateConfigurations(int numOfCandidates, IEnu /// Number of candidate configurations returned by the method (top K). /// Historical run results. /// Array of parameter sets, which will then be evaluated. - private ParameterSet[] GreedyPlusRandomSearch(ParameterSet[] parents, FastForestRegressionPredictor forest, int numOfCandidates, IEnumerable previousRuns) + private ParameterSet[] GreedyPlusRandomSearch(ParameterSet[] parents, FastForestRegressionModelParameters forest, int numOfCandidates, IEnumerable previousRuns) { // REVIEW: The IsMetricMaximizing flag affects the comparator, so that // performing Max() should get the best, regardless of if it is maximizing or @@ -231,7 +231,7 @@ private ParameterSet[] GreedyPlusRandomSearch(ParameterSet[] parents, FastForest /// Best performance seen thus far. 
/// Threshold for when to stop the local search. /// - private Tuple LocalSearch(ParameterSet parent, FastForestRegressionPredictor forest, double bestVal, double epsilon) + private Tuple LocalSearch(ParameterSet parent, FastForestRegressionModelParameters forest, double bestVal, double epsilon) { try { @@ -332,7 +332,7 @@ private ParameterSet[] GetOneMutationNeighborhood(ParameterSet parent) /// Trained forest predictor, used for filtering configs. /// Parameter configurations. /// 2D array where rows correspond to configurations, and columns to the predicted leaf values. - private double[][] GetForestRegressionLeafValues(FastForestRegressionPredictor forest, ParameterSet[] configs) + private double[][] GetForestRegressionLeafValues(FastForestRegressionModelParameters forest, ParameterSet[] configs) { List datasetLeafValues = new List(); var e = forest.TrainedEnsemble; @@ -369,14 +369,14 @@ private double[][] ComputeForestStats(double[][] leafValues) return meansAndStdDevs; } - private double[] EvaluateConfigurationsByEI(FastForestRegressionPredictor forest, double bestVal, ParameterSet[] configs) + private double[] EvaluateConfigurationsByEI(FastForestRegressionModelParameters forest, double bestVal, ParameterSet[] configs) { double[][] leafPredictions = GetForestRegressionLeafValues(forest, configs); double[][] forestStatistics = ComputeForestStats(leafPredictions); return ComputeEIs(bestVal, forestStatistics); } - private ParameterSet[] GetKBestConfigurations(IEnumerable previousRuns, FastForestRegressionPredictor forest, int k = 10) + private ParameterSet[] GetKBestConfigurations(IEnumerable previousRuns, FastForestRegressionModelParameters forest, int k = 10) { // NOTE: Should we change this to rank according to EI (using forest), instead of observed performance? 
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 107c83eb7d..a32ded52bf 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -347,7 +347,7 @@ public void EntryPointCatalogCheckDuplicateParams() private (IEnumerable epListContents, JObject manifest) BuildManifests() { - Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryPredictor).Assembly); + Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryModelParameters).Assembly); Env.ComponentCatalog.RegisterAssembly(typeof(TensorFlowTransform).Assembly); Env.ComponentCatalog.RegisterAssembly(typeof(ImageLoaderTransform).Assembly); Env.ComponentCatalog.RegisterAssembly(typeof(SymSgdClassificationTrainer).Assembly); @@ -1952,14 +1952,14 @@ public void EntryPointEvaluateRanking() [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only public void EntryPointLightGbmBinary() { - Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryPredictor).Assembly); + Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryModelParameters).Assembly); TestEntryPointRoutine("breast-cancer.txt", "Trainers.LightGbmBinaryClassifier"); } [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // LightGBM is 64-bit only public void EntryPointLightGbmMultiClass() { - Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryPredictor).Assembly); + Env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryModelParameters).Assembly); TestEntryPointRoutine(GetDataPath(@"iris.txt"), "Trainers.LightGbmClassifier"); } diff --git a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs index f96f6f205a..b763d3ad01 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs @@ -43,7 
+43,7 @@ protected override void InitializeEnvironment(IHostEnvironment environment) { base.InitializeEnvironment(environment); - environment.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryPredictor).Assembly); + environment.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryModelParameters).Assembly); environment.ComponentCatalog.RegisterAssembly(typeof(SymSgdClassificationTrainer).Assembly); } diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index 9d4a2bc010..1a8f78b794 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -423,7 +423,7 @@ public void FastTreeRegression() c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)), separator: ';', hasHeader: true); - FastTreeRegressionPredictor pred = null; + FastTreeRegressionModelParameters pred = null; var est = reader.MakeNewEstimator() .Append(r => (r.label, score: ctx.Trainers.FastTree(r.label, r.features, @@ -505,7 +505,7 @@ public void LightGbmRegression() c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)), separator: ';', hasHeader: true); - LightGbmRegressionPredictor pred = null; + LightGbmRegressionModelParameters pred = null; var est = reader.MakeNewEstimator() .Append(r => (r.label, score: ctx.Trainers.LightGbm(r.label, r.features, @@ -770,7 +770,7 @@ public void FastTreeRanking() c => (label: c.LoadFloat(0), features: c.LoadFloat(9, 14), groupId: c.LoadText(1)), separator: '\t', hasHeader: true); - FastTreeRankingPredictor pred = null; + FastTreerankingModelParameters pred = null; var est = reader.MakeNewEstimator() .Append(r => (r.label, r.features, groupId: r.groupId.ToKey())) @@ -811,7 +811,7 @@ public void LightGBMRanking() c => (label: c.LoadFloat(0), features: c.LoadFloat(9, 14), groupId: c.LoadText(1)), separator: '\t', hasHeader: true); - LightGbmRankingPredictor pred = null; + LightGbmRankingModelParameters pred = null; var est = 
reader.MakeNewEstimator() .Append(r => (r.label, r.features, groupId: r.groupId.ToKey())) diff --git a/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs b/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs index ccd041f7a5..a95bca433a 100644 --- a/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs +++ b/test/Microsoft.ML.TestFramework/EnvironmentExtensions.cs @@ -22,7 +22,7 @@ public static TEnvironment AddStandardComponents(this TEnvironment env.ComponentCatalog.RegisterAssembly(typeof(TextLoader).Assembly); // ML.Data env.ComponentCatalog.RegisterAssembly(typeof(LinearPredictor).Assembly); // ML.StandardLearners env.ComponentCatalog.RegisterAssembly(typeof(OneHotEncodingTransformer).Assembly); // ML.Transforms - env.ComponentCatalog.RegisterAssembly(typeof(FastTreeBinaryPredictor).Assembly); // ML.FastTree + env.ComponentCatalog.RegisterAssembly(typeof(FastTreeBinaryModelParameters).Assembly); // ML.FastTree env.ComponentCatalog.RegisterAssembly(typeof(EnsemblePredictor).Assembly); // ML.Ensemble env.ComponentCatalog.RegisterAssembly(typeof(KMeansModelParameters).Assembly); // ML.KMeansClustering env.ComponentCatalog.RegisterAssembly(typeof(PcaPredictor).Assembly); // ML.PCA From 9fda3d63f916cc6448f95a0534606001a1ac8f4a Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Thu, 6 Dec 2018 19:53:29 -0800 Subject: [PATCH 3/6] Internalize and explicitly implement IValueMapperDist --- .../Prediction/Calibrator.cs | 12 +++---- .../Trainer/EnsembleDistributionPredictor.cs | 32 +++++++++-------- .../Standard/Simple/SimpleTrainers.cs | 34 +++++++++++-------- 3 files changed, 42 insertions(+), 36 deletions(-) diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 716f6aaf40..57b06f209b 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -221,9 +221,9 @@ public abstract class ValueMapperCalibratedPredictorBase : 
CalibratedPredictorBa private readonly IValueMapper _mapper; private readonly IFeatureContributionMapper _featureContribution; - public ColumnType InputType => _mapper.InputType; - public ColumnType OutputType => _mapper.OutputType; - public ColumnType DistType => NumberType.Float; + ColumnType IValueMapper.InputType => _mapper.InputType; + ColumnType IValueMapper.OutputType => _mapper.OutputType; + ColumnType IValueMapperDist.DistType => NumberType.Float; bool ICanSavePfa.CanSavePfa => (_mapper as ICanSavePfa)?.CanSavePfa == true; bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (_mapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; @@ -239,16 +239,16 @@ protected ValueMapperCalibratedPredictorBase(IHostEnvironment env, string name, _featureContribution = predictor as IFeatureContributionMapper; } - public ValueMapper GetMapper() + ValueMapper IValueMapper.GetMapper() { return _mapper.GetMapper(); } - public ValueMapper GetMapper() + ValueMapper IValueMapperDist.GetMapper() { Host.Check(typeof(TOut) == typeof(Float)); Host.Check(typeof(TDist) == typeof(Float)); - var map = GetMapper(); + var map = ((IValueMapper)this).GetMapper(); ValueMapper del = (in TIn src, ref Float score, ref Float prob) => { diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs index 6346abfc80..9cccc362b4 100644 --- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs +++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionPredictor.cs @@ -24,9 +24,9 @@ namespace Microsoft.ML.Runtime.Ensemble public sealed class EnsembleDistributionPredictor : EnsemblePredictorBase, TDistPredictor, IValueMapperDist { - public const string UserName = "Ensemble Distribution Executor"; - public const string LoaderSignature = "EnsemDbExec"; - public const string RegistrationName = "EnsembleDistributionPredictor"; + internal const string UserName = "Ensemble Distribution Executor"; + internal const 
string LoaderSignature = "EnsemDbExec"; + internal const string RegistrationName = "EnsembleDistributionPredictor"; private static VersionInfo GetVersionInfo() { @@ -45,9 +45,11 @@ private static VersionInfo GetVersionInfo() private readonly Median _probabilityCombiner; private readonly IValueMapperDist[] _mappers; - public ColumnType InputType { get; } - public ColumnType OutputType => NumberType.Float; - public ColumnType DistType => NumberType.Float; + private readonly ColumnType _inputType; + + ColumnType IValueMapper.InputType => _inputType; + ColumnType IValueMapper.OutputType => NumberType.Float; + ColumnType IValueMapperDist.DistType => NumberType.Float; public override PredictionKind PredictionKind { get; } @@ -57,7 +59,7 @@ internal EnsembleDistributionPredictor(IHostEnvironment env, PredictionKind kind { PredictionKind = kind; _probabilityCombiner = new Median(env); - InputType = InitializeMappers(out _mappers); + _inputType = InitializeMappers(out _mappers); ComputeAveragedWeights(out _averagedWeights); } @@ -66,7 +68,7 @@ private EnsembleDistributionPredictor(IHostEnvironment env, ModelLoadContext ctx { PredictionKind = (PredictionKind)ctx.Reader.ReadInt32(); _probabilityCombiner = new Median(env); - InputType = InitializeMappers(out _mappers); + _inputType = InitializeMappers(out _mappers); ComputeAveragedWeights(out _averagedWeights); } @@ -101,7 +103,7 @@ private bool IsValid(IValueMapperDist mapper) && mapper.DistType == NumberType.Float; } - public static EnsembleDistributionPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static EnsembleDistributionPredictor Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); @@ -119,7 +121,7 @@ private protected override void SaveCore(ModelSaveContext ctx) ctx.Writer.Write((int)PredictionKind); } - public ValueMapper GetMapper() + ValueMapper IValueMapper.GetMapper() { Host.Check(typeof(TIn) == typeof(VBuffer)); 
Host.Check(typeof(TOut) == typeof(Single)); @@ -132,8 +134,8 @@ public ValueMapper GetMapper() ValueMapper, Single> del = (in VBuffer src, ref Single dst) => { - if (InputType.VectorSize > 0) - Host.Check(src.Length == InputType.VectorSize); + if (_inputType.VectorSize > 0) + Host.Check(src.Length == _inputType.VectorSize); var tmp = src; Parallel.For(0, maps.Length, i => @@ -155,7 +157,7 @@ public ValueMapper GetMapper() return (ValueMapper)(Delegate)del; } - public ValueMapper GetMapper() + ValueMapper IValueMapperDist.GetMapper() { Host.Check(typeof(TIn) == typeof(VBuffer)); Host.Check(typeof(TOut) == typeof(Single)); @@ -170,8 +172,8 @@ public ValueMapper GetMapper() ValueMapper, Single, Single> del = (in VBuffer src, ref Single score, ref Single prob) => { - if (InputType.VectorSize > 0) - Host.Check(src.Length == InputType.VectorSize); + if (_inputType.VectorSize > 0) + Host.Check(src.Length == _inputType.VectorSize); var tmp = src; Parallel.For(0, maps.Length, i => diff --git a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs index 049ce8b8d5..997ca40d3d 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs @@ -129,9 +129,11 @@ private static VersionInfo GetVersionInfo() private readonly Random _random; public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - public ColumnType InputType { get; } - public ColumnType OutputType => NumberType.Float; - public ColumnType DistType => NumberType.Float; + + private readonly ColumnType _inputType; + ColumnType IValueMapper.InputType => _inputType; + ColumnType IValueMapper.OutputType => NumberType.Float; + ColumnType IValueMapperDist.DistType => NumberType.Float; public RandomPredictor(IHostEnvironment env, int seed) : base(env, LoaderSignature) @@ -141,7 +143,7 @@ public 
RandomPredictor(IHostEnvironment env, int seed) _instanceLock = new object(); _random = RandomUtils.Create(_seed); - InputType = new VectorType(NumberType.Float); + _inputType = new VectorType(NumberType.Float); } /// @@ -158,10 +160,10 @@ private RandomPredictor(IHostEnvironment env, ModelLoadContext ctx) _instanceLock = new object(); _random = RandomUtils.Create(_seed); - InputType = new VectorType(NumberType.Float); + _inputType = new VectorType(NumberType.Float); } - public static RandomPredictor Create(IHostEnvironment env, ModelLoadContext ctx) + private static RandomPredictor Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); @@ -184,7 +186,7 @@ private protected override void SaveCore(ModelSaveContext ctx) ctx.Writer.Write(_seed); } - public ValueMapper GetMapper() + ValueMapper IValueMapper.GetMapper() { Contracts.Check(typeof(TIn) == typeof(VBuffer)); Contracts.Check(typeof(TOut) == typeof(float)); @@ -193,7 +195,7 @@ public ValueMapper GetMapper() return (ValueMapper)(Delegate)del; } - public ValueMapper GetMapper() + ValueMapper IValueMapperDist.GetMapper() { Contracts.Check(typeof(TIn) == typeof(VBuffer)); Contracts.Check(typeof(TOut) == typeof(float)); @@ -371,7 +373,7 @@ public PriorPredictor(IHostEnvironment env, float prob) _prob = prob; _raw = 2 * _prob - 1; // This could be other functions -- logodds for instance - InputType = new VectorType(NumberType.Float); + _inputType = new VectorType(NumberType.Float); } private PriorPredictor(IHostEnvironment env, ModelLoadContext ctx) @@ -385,7 +387,7 @@ private PriorPredictor(IHostEnvironment env, ModelLoadContext ctx) _raw = 2 * _prob - 1; - InputType = new VectorType(NumberType.Float); + _inputType = new VectorType(NumberType.Float); } public static PriorPredictor Create(IHostEnvironment env, ModelLoadContext ctx) @@ -410,11 +412,13 @@ private protected override void SaveCore(ModelSaveContext ctx) public override PredictionKind 
PredictionKind { get { return PredictionKind.BinaryClassification; } } - public ColumnType InputType { get; } - public ColumnType OutputType => NumberType.Float; - public ColumnType DistType => NumberType.Float; - public ValueMapper GetMapper() + private readonly ColumnType _inputType; + ColumnType IValueMapper.InputType => _inputType; + ColumnType IValueMapper.OutputType => NumberType.Float; + ColumnType IValueMapperDist.DistType => NumberType.Float; + + ValueMapper IValueMapper.GetMapper() { Contracts.Check(typeof(TIn) == typeof(VBuffer)); Contracts.Check(typeof(TOut) == typeof(float)); @@ -423,7 +427,7 @@ public ValueMapper GetMapper() return (ValueMapper)(Delegate)del; } - public ValueMapper GetMapper() + ValueMapper IValueMapperDist.GetMapper() { Contracts.Check(typeof(TIn) == typeof(VBuffer)); Contracts.Check(typeof(TOut) == typeof(float)); From f4e35b24e4e90cec139e7ee4a2151eb98f197f01 Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Mon, 10 Dec 2018 20:44:32 -0800 Subject: [PATCH 4/6] Adding public constructors and sample --- .../Dynamic/FastTreeRegression.cs | 48 +++++++++++++++++++ src/Microsoft.ML.FastTree/FastTree.cs | 2 +- .../FastTreeClassification.cs | 2 +- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 2 +- .../FastTreeRegression.cs | 2 +- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 2 +- .../OlsLinearRegression.cs | 2 +- .../LightGbmBinaryTrainer.cs | 2 +- .../LightGbmRankingTrainer.cs | 2 +- .../LightGbmRegressionTrainer.cs | 2 +- .../Standard/LinearPredictor.cs | 8 ++-- 11 files changed, 61 insertions(+), 13 deletions(-) create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/FastTreeRegression.cs diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/FastTreeRegression.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/FastTreeRegression.cs new file mode 100644 index 0000000000..ddc9105556 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/FastTreeRegression.cs @@ -0,0 +1,48 @@ +using Microsoft.ML.Runtime.Api; +using 
Microsoft.ML.Runtime.Data; +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Microsoft.ML.Samples.Dynamic +{ + public class FastTreeRegressionExample + { + public static void FastTreeRegression() + { + // Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging, + // as well as the source of randomness. + var ml = new MLContext(); + + // Get a small dataset as an IEnumerable and convert it to an IDataView. + var data = SamplesUtils.DatasetUtils.GetInfertData(); + var trainData = ml.CreateStreamingDataView(data); + + // Preview of the data. + // + // Age Case Education Induced Parity PooledStratum RowNum ... + // 26 1 0-5yrs 1 6 3 1 ... + // 42 1 0-5yrs 1 1 1 2 ... + // 39 1 0-5yrs 2 6 4 3 ... + // 34 1 0-5yrs 2 4 2 4 ... + // 35 1 6-11yrs 1 3 32 5 ... + + // A pipeline for concatenating the parity and induced columns together in the Features column and training a FastTreeRegression model on them. + string outputColumnName = "Features"; + var pipeline = ml.Transforms.Concatenate(outputColumnName, new[] { "Parity", "Induced" }) + .Append(ml.Regression.Trainers.FastTree(labelColumn: "Age", featureColumn: outputColumnName, numTrees: 2, numLeaves: 2, minDatapointsInLeaves: 1)); + + var model = pipeline.Fit(trainData); + + // Get the trained model parameters. + var modelParams = model.LastTransformer.Model; + + // Get the leaf and the leaf value for a row of data with Parity = 1, Induced = 1 in the first tree. 
+ var testRow = new VBuffer(2, new[] { 1.0f, 1.0f }); + List path = default; + var leaf = modelParams.GetLeaf(0, in testRow, ref path); + var leafValue = modelParams.GetLeafValue(0, leaf); + Console.WriteLine("The leaf value in tree 0 is: " + leafValue); + } + } +} diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 7a17421853..6e70b73e21 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -2839,7 +2839,7 @@ public abstract class TreeEnsembleModelParameters : bool ICanSavePfa.CanSavePfa => true; bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; - protected TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemble trainedEnsemble, int numFeatures, string innerArgs) + public TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemble trainedEnsemble, int numFeatures, string innerArgs) : base(env, name) { Host.CheckValue(trainedEnsemble, nameof(trainedEnsemble)); diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index dd3017704f..d933b5dd85 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -69,7 +69,7 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010005; - internal FastTreeBinaryModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + public FastTreeBinaryModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 31497d5168..6b766787b5 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -1130,7 +1130,7 
@@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010005; - internal FastTreerankingModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + public FastTreerankingModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 811fbf951a..5a19f91e54 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -467,7 +467,7 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010005; - internal FastTreeRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + public FastTreeRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 26ab590559..aa0b4e90e0 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -470,7 +470,7 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010003; - internal FastTreeTweedieModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + public FastTreeTweedieModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs 
b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs index 5648aa29e4..1f35333844 100644 --- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs +++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs @@ -738,7 +738,7 @@ public static OlsLinearRegressionPredictor Create(IHostEnvironment env, ModelLoa return new OlsLinearRegressionPredictor(env, ctx); } - private protected override void SaveSummary(TextWriter writer, RoleMappedSchema schema) + internal override void SaveSummary(TextWriter writer, RoleMappedSchema schema) { var names = default(VBuffer>); MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, Weight.Length, ref names); diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index cdf06e19f1..45d55f4bce 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -55,7 +55,7 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010005; public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; - internal LightGbmBinaryModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + public LightGbmBinaryModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index 9d2e6a45d8..fe1909c380 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -51,7 +51,7 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010005; public override PredictionKind PredictionKind => PredictionKind.Ranking; - internal 
LightGbmRankingModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + public LightGbmRankingModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs index 5d2e2f1aa7..9e3868bf96 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs @@ -51,7 +51,7 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010005; public override PredictionKind PredictionKind => PredictionKind.Regression; - internal LightGbmRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + public LightGbmRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index 45978c9d52..1e36b5cf29 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -354,7 +354,7 @@ void ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema) } [BestFriend] - private protected abstract void SaveSummary(TextWriter writer, RoleMappedSchema schema); + internal abstract void SaveSummary(TextWriter writer, RoleMappedSchema schema); void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) => SaveSummary(writer, schema); @@ -499,7 +499,7 @@ public IParameterMixer CombineParameters(IList> mo return new LinearBinaryPredictor(Host, in weights, bias); } - 
private protected override void SaveSummary(TextWriter writer, RoleMappedSchema schema) + internal override void SaveSummary(TextWriter writer, RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); @@ -629,7 +629,7 @@ private protected override void SaveCore(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); } - private protected override void SaveSummary(TextWriter writer, RoleMappedSchema schema) + internal override void SaveSummary(TextWriter writer, RoleMappedSchema schema) { Host.CheckValue(writer, nameof(writer)); Host.CheckValue(schema, nameof(schema)); @@ -710,7 +710,7 @@ protected override Float Score(in VBuffer src) return MathUtils.ExpSlow(base.Score(in src)); } - private protected override void SaveSummary(TextWriter writer, RoleMappedSchema schema) + internal override void SaveSummary(TextWriter writer, RoleMappedSchema schema) { Host.CheckValue(writer, nameof(writer)); Host.CheckValue(schema, nameof(schema)); From a5aa3b9eb954136e2a5c899b9fbd9578597761d2 Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Mon, 10 Dec 2018 23:44:30 -0800 Subject: [PATCH 5/6] nit --- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 28 +++++++++---------- .../TreeEnsemble/TreeEnsembleCombiner.cs | 2 +- .../TreeTrainersStatic.cs | 2 +- .../Training.cs | 2 +- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 6b766787b5..efdd534577 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -33,9 +33,9 @@ "frrank", "btrank")] -[assembly: LoadableClass(typeof(FastTreerankingModelParameters), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(FastTreeRankingModelParameters), null, typeof(SignatureLoadModel), "FastTree Ranking Executor", - FastTreerankingModelParameters.LoaderSignature)] + FastTreeRankingModelParameters.LoaderSignature)] [assembly: LoadableClass(typeof(void), 
typeof(FastTree), null, typeof(SignatureEntryPointModule), "FastTree")] @@ -43,7 +43,7 @@ namespace Microsoft.ML.Trainers.FastTree { /// public sealed partial class FastTreeRankingTrainer - : BoostingFastTreeTrainerBase, FastTreerankingModelParameters> + : BoostingFastTreeTrainerBase, FastTreeRankingModelParameters> { internal const string LoadNameValue = "FastTreeRanking"; internal const string UserNameValue = "FastTree (Boosted Trees) Ranking"; @@ -112,7 +112,7 @@ protected override float GetMaxLabel() return GetLabelGains().Length - 1; } - private protected override FastTreerankingModelParameters TrainModelCore(TrainContext context) + private protected override FastTreeRankingModelParameters TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var trainData = context.TrainingSet; @@ -126,7 +126,7 @@ private protected override FastTreerankingModelParameters TrainModelCore(TrainCo TrainCore(ch); FeatureCount = trainData.Schema.Feature.Type.ValueCount; } - return new FastTreerankingModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs); + return new FastTreeRankingModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs); } private Double[] GetLabelGains() @@ -454,10 +454,10 @@ protected override string GetTestGraphHeader() return headerBuilder.ToString(); } - protected override RankingPredictionTransformer MakeTransformer(FastTreerankingModelParameters model, Schema trainSchema) - => new RankingPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); + protected override RankingPredictionTransformer MakeTransformer(FastTreeRankingModelParameters model, Schema trainSchema) + => new RankingPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name); - public RankingPredictionTransformer Train(IDataView trainData, IDataView validationData = null) + public RankingPredictionTransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData); protected 
override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) @@ -1104,7 +1104,7 @@ private static extern unsafe void GetDerivatives( } } - public sealed class FastTreerankingModelParameters : TreeEnsembleModelParameters + public sealed class FastTreeRankingModelParameters : TreeEnsembleModelParameters { internal const string LoaderSignature = "FastTreeRankerExec"; internal const string RegistrationName = "FastTreeRankingPredictor"; @@ -1121,7 +1121,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010004, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(FastTreerankingModelParameters).Assembly.FullName); + loaderAssemblyName: typeof(FastTreeRankingModelParameters).Assembly.FullName); } protected override uint VerNumFeaturesSerialized => 0x00010002; @@ -1130,12 +1130,12 @@ private static VersionInfo GetVersionInfo() protected override uint VerCategoricalSplitSerialized => 0x00010005; - public FastTreerankingModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) + public FastTreeRankingModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) { } - private FastTreerankingModelParameters(IHostEnvironment env, ModelLoadContext ctx) + private FastTreeRankingModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) { } @@ -1146,9 +1146,9 @@ private protected override void SaveCore(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); } - private static FastTreerankingModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) + private static FastTreeRankingModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { - return new FastTreerankingModelParameters(env, ctx); + return new FastTreeRankingModelParameters(env, ctx); } public override 
PredictionKind PredictionKind => PredictionKind.Ranking; diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs index 754c18f2f6..e621b87b8a 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs @@ -106,7 +106,7 @@ public IPredictor CombineModels(IEnumerable models) case PredictionKind.Regression: return new FastTreeRegressionModelParameters(_host, ensemble, featureCount, null); case PredictionKind.Ranking: - return new FastTreerankingModelParameters(_host, ensemble, featureCount, null); + return new FastTreeRankingModelParameters(_host, ensemble, featureCount, null); default: _host.Assert(false); throw _host.ExceptNotSupp(); diff --git a/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs b/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs index debe52f7bd..5893a06eda 100644 --- a/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs @@ -144,7 +144,7 @@ public static Scalar FastTree(this RankingContext.RankingTrainers c int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves, double learningRate = Defaults.LearningRates, Action advancedSettings = null, - Action onFit = null) + Action onFit = null) { CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate, advancedSettings, onFit); diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index 1a8f78b794..47613f42a8 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -770,7 +770,7 @@ public void FastTreeRanking() c => (label: c.LoadFloat(0), features: c.LoadFloat(9, 14), groupId: c.LoadText(1)), separator: '\t', hasHeader: true); - FastTreerankingModelParameters pred = null; + FastTreeRankingModelParameters pred = null; 
var est = reader.MakeNewEstimator() .Append(r => (r.label, r.features, groupId: r.groupId.ToKey())) From 28f2fa8ced22eb50994c48e5afb554a454383952 Mon Sep 17 00:00:00 2001 From: "REDMOND\\nakazmi" Date: Tue, 11 Dec 2018 12:30:21 -0800 Subject: [PATCH 6/6] Address comments --- .../Dynamic/FastTreeRegression.cs | 16 ++++++++++------ src/Microsoft.ML.FastTree/FastTree.cs | 3 ++- .../FastTreeClassification.cs | 2 +- .../Properties/AssemblyInfo.cs | 2 ++ .../OlsLinearRegression.cs | 2 +- .../Standard/LinearPredictor.cs | 11 ++++------- .../Standard/Simple/SimpleTrainers.cs | 2 +- 7 files changed, 21 insertions(+), 17 deletions(-) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/FastTreeRegression.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/FastTreeRegression.cs index ddc9105556..e4e3861d2b 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/FastTreeRegression.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/FastTreeRegression.cs @@ -27,22 +27,26 @@ public static void FastTreeRegression() // 34 1 0-5yrs 2 4 2 4 ... // 35 1 6-11yrs 1 3 32 5 ... - // A pipeline for concatenating the parity and induced columns together in the Features column and training a FastTreeRegression model on them. + // A pipeline for concatenating the Parity and Induced columns together in the Features column. + // We will train a FastTreeRegression model with 1 tree on these two columns to predict Age. string outputColumnName = "Features"; var pipeline = ml.Transforms.Concatenate(outputColumnName, new[] { "Parity", "Induced" }) - .Append(ml.Regression.Trainers.FastTree(labelColumn: "Age", featureColumn: outputColumnName, numTrees: 2, numLeaves: 2, minDatapointsInLeaves: 1)); + .Append(ml.Regression.Trainers.FastTree(labelColumn: "Age", featureColumn: outputColumnName, numTrees: 1, numLeaves: 2, minDatapointsInLeaves: 1)); var model = pipeline.Fit(trainData); // Get the trained model parameters. 
var modelParams = model.LastTransformer.Model; - // Get the leaf and the leaf value for a row of data with Parity = 1, Induced = 1 in the first tree. + // Let's see where an example with Parity = 1 and Induced = 1 would end up in the single trained tree. var testRow = new VBuffer(2, new[] { 1.0f, 1.0f }); + // Use the path object to pass to GetLeaf, which will populate path with the IDs of the nodes from root to leaf. List path = default; - var leaf = modelParams.GetLeaf(0, in testRow, ref path); - var leafValue = modelParams.GetLeafValue(0, leaf); - Console.WriteLine("The leaf value in tree 0 is: " + leafValue); + // Get the ID of the leaf this example ends up in tree 0. + var leafID = modelParams.GetLeaf(0, in testRow, ref path); + // Get the leaf value for this leaf ID in tree 0. + var leafValue = modelParams.GetLeafValue(0, leafID); + Console.WriteLine("The leaf value in tree 0 is: " + leafValue); } } } diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 6e70b73e21..c790783cc9 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -2811,7 +2811,8 @@ public abstract class TreeEnsembleModelParameters : ISingleCanSaveOnnx { //The below two properties are necessary for tree Visualizer - public TreeEnsemble TrainedEnsemble { get; } + [BestFriend] + internal TreeEnsemble TrainedEnsemble { get; } int ITreeEnsemble.NumTrees => TrainedEnsemble.NumTrees; // Inner args is used only for documentation purposes when saving comments to INI files. diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index d933b5dd85..5fd1feaa94 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -108,7 +108,7 @@ public sealed partial class FastTreeBinaryClassificationTrainer : /// /// The LoadName for the assembly containing the trainer.
/// - public const string LoadNameValue = "FastTreeBinaryClassification"; + internal const string LoadNameValue = "FastTreeBinaryClassification"; internal const string UserNameValue = "FastTree (Boosted Trees) Classification"; internal const string Summary = "Uses a logit-boost boosted tree learner to perform binary classification."; internal const string ShortName = "ftc"; diff --git a/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs b/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs index cd27563c10..cf6d8d8d42 100644 --- a/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs @@ -6,6 +6,8 @@ using Microsoft.ML; [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Core.Tests" + PublicKey.TestValue)] + [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.LightGBM" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Sweeper" + PublicKey.Value)] [assembly: WantsToBeBestFriends] diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs index 1f35333844..5648aa29e4 100644 --- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs +++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs @@ -738,7 +738,7 @@ public static OlsLinearRegressionPredictor Create(IHostEnvironment env, ModelLoa return new OlsLinearRegressionPredictor(env, ctx); } - internal override void SaveSummary(TextWriter writer, RoleMappedSchema schema) + private protected override void SaveSummary(TextWriter writer, RoleMappedSchema schema) { var names = default(VBuffer>); MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, Weight.Length, ref names); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index 1e36b5cf29..b5e1271c4c 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ 
b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -354,11 +354,10 @@ void ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema) } [BestFriend] - internal abstract void SaveSummary(TextWriter writer, RoleMappedSchema schema); + private protected abstract void SaveSummary(TextWriter writer, RoleMappedSchema schema); void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) => SaveSummary(writer, schema); - [BestFriend] private protected virtual Row GetSummaryIRowOrNull(RoleMappedSchema schema) { var names = default(VBuffer>); @@ -374,12 +373,10 @@ private protected virtual Row GetSummaryIRowOrNull(RoleMappedSchema schema) Row ICanGetSummaryAsIRow.GetSummaryIRowOrNull(RoleMappedSchema schema) => GetSummaryIRowOrNull(schema); - [BestFriend] private protected virtual Row GetStatsIRowOrNull(RoleMappedSchema schema) => null; Row ICanGetSummaryAsIRow.GetStatsIRowOrNull(RoleMappedSchema schema) => GetStatsIRowOrNull(schema); - [BestFriend] private protected abstract void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator = null); void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator) => SaveAsIni(writer, schema, calibrator); @@ -499,7 +496,7 @@ public IParameterMixer CombineParameters(IList> mo return new LinearBinaryPredictor(Host, in weights, bias); } - internal override void SaveSummary(TextWriter writer, RoleMappedSchema schema) + private protected override void SaveSummary(TextWriter writer, RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); @@ -629,7 +626,7 @@ private protected override void SaveCore(ModelSaveContext ctx) ctx.SetVersionInfo(GetVersionInfo()); } - internal override void SaveSummary(TextWriter writer, RoleMappedSchema schema) + private protected override void SaveSummary(TextWriter writer, RoleMappedSchema schema) { Host.CheckValue(writer, nameof(writer)); Host.CheckValue(schema, nameof(schema)); @@ -710,7 
+707,7 @@ protected override Float Score(in VBuffer src) return MathUtils.ExpSlow(base.Score(in src)); } - internal override void SaveSummary(TextWriter writer, RoleMappedSchema schema) + private protected override void SaveSummary(TextWriter writer, RoleMappedSchema schema) { Host.CheckValue(writer, nameof(writer)); Host.CheckValue(schema, nameof(schema)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs index 997ca40d3d..133a477502 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs @@ -111,7 +111,7 @@ public sealed class RandomPredictor : IValueMapperDist, ICanSaveModel { - public const string LoaderSignature = "RandomPredictor"; + internal const string LoaderSignature = "RandomPredictor"; private static VersionInfo GetVersionInfo() { return new VersionInfo(