diff --git a/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs b/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs new file mode 100644 index 0000000000..29de619946 --- /dev/null +++ b/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs @@ -0,0 +1,54 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Data; +using System; + +namespace Microsoft.ML.Runtime.Data +{ + /// + /// This interface maps an input to an output . Typically, the output contains + /// both the input columns and new columns added by the implementing class, although some implementations may + /// return a subset of the input columns. + /// This interface is similar to , except it does not have any input role mappings, + /// so to rebind, the same input column names must be used. + /// Implementations of this interface are typically created over defined input . + /// + public interface IRowToRowMapper + { + /// + /// Mappers are defined as accepting inputs with this very specific schema. + /// + Schema InputSchema { get; } + + /// + /// Gets an instance of which describes the columns' names and types in the output generated by this mapper. + /// + Schema OutputSchema { get; } + + /// + /// Given a predicate specifying which columns are needed, return a predicate indicating which input columns are + /// needed. The domain of the function is defined over the indices of the columns of + /// for . + /// + Func GetDependencies(Func predicate); + + /// + /// Get an with the indicated active columns, based on the input . + /// The active columns are those for which returns true. Getting values on inactive + /// columns of the returned row will throw. Null predicates are disallowed. + /// + /// The of should be the same object as + /// . Implementors of this method should throw if that is not the case. Conversely, + /// the returned value must have the same schema as . + /// + /// This method creates a live connection between the input and the output . In particular, when the getters of the output are invoked, they invoke the + /// getters of the input row and base the output values on the current values of the input . + /// The output values are re-computed when requested through the getters. Also, the returned + /// will dispose when it is disposed. + /// + Row GetRow(Row input, Func active); + } +} diff --git a/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs b/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs index caa7b15330..ff005d157d 100644 --- a/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs +++ b/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs @@ -3,7 +3,6 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.Data; -using System; using System.Collections.Generic; namespace Microsoft.ML.Runtime.Data @@ -21,7 +20,8 @@ namespace Microsoft.ML.Runtime.Data /// for the output schema of the . In case the interface is implemented, /// the SimpleRow class can be used in the method. /// - public interface ISchemaBindableMapper + [BestFriend] + internal interface ISchemaBindableMapper { ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema); } @@ -30,7 +30,8 @@ public interface ISchemaBindableMapper /// This interface is used to map a schema from input columns to output columns. The should keep track /// of the input columns that are needed for the mapping. /// - public interface ISchemaBoundMapper + [BestFriend] + internal interface ISchemaBoundMapper { /// /// The that was passed to the in the binding process. @@ -56,7 +57,8 @@ public interface ISchemaBoundMapper /// /// This interface combines with . /// - public interface ISchemaBoundRowMapper : ISchemaBoundMapper, IRowToRowMapper + [BestFriend] + internal interface ISchemaBoundRowMapper : ISchemaBoundMapper, IRowToRowMapper { /// /// There are two schemas from and . @@ -64,49 +66,4 @@ public interface ISchemaBoundRowMapper : ISchemaBoundMapper, IRowToRowMapper /// new Schema OutputSchema { get; } } - - /// - /// This interface maps an input to an output . Typically, the output contains - /// both the input columns and new columns added by the implementing class, although some implementations may - /// return a subset of the input columns. - /// This interface is similar to , except it does not have any input role mappings, - /// so to rebind, the same input column names must be used. - /// Implementations of this interface are typically created over defined input . - /// - public interface IRowToRowMapper - { - /// - /// Mappers are defined as accepting inputs with this very specific schema. - /// - Schema InputSchema { get; } - - /// - /// Gets an instance of which describes the columns' names and types in the output generated by this mapper. - /// - Schema OutputSchema { get; } - - /// - /// Given a predicate specifying which columns are needed, return a predicate indicating which input columns are - /// needed. The domain of the function is defined over the indices of the columns of - /// for . - /// - Func GetDependencies(Func predicate); - - /// - /// Get an with the indicated active columns, based on the input . - /// The active columns are those for which returns true. Getting values on inactive - /// columns of the returned row will throw. Null predicates are disallowed. - /// - /// The of should be the same object as - /// . Implementors of this method should throw if that is not the case. Conversely, - /// the returned value must have the same schema as . - /// - /// This method creates a live connection between the input and the output . In particular, when the getters of the output are invoked, they invoke the - /// getters of the input row and base the output values on the current values of the input . - /// The output values are re-computed when requested through the getters. Also, the returned - /// will dispose when it is disposed. - /// - Row GetRow(Row input, Func active); - } } diff --git a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs index 94ac81af9c..9a3f02cc24 100644 --- a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs +++ b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs @@ -100,7 +100,8 @@ public static ColumnInfo CreateFromIndex(Schema schema, int index) /// /// /// - public sealed class RoleMappedSchema + [BestFriend] + internal sealed class RoleMappedSchema { private const string FeatureString = "Feature"; private const string LabelString = "Label"; diff --git a/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs b/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs index 633f871dd0..c537cbd3ab 100644 --- a/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs +++ b/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs @@ -114,7 +114,8 @@ private Stream CreateStrm(IFileHandle file) } } - public static class SavePredictorUtils + [BestFriend] + internal static class SavePredictorUtils { public static void SavePredictor(IHostEnvironment env, Stream modelStream, Stream binaryModelStream = null, Stream summaryModelStream = null, Stream textModelStream = null, Stream iniModelStream = null, Stream codeModelStream = null) diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs index 89c7572b63..fe289fa4f7 100644 --- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs +++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs @@ -34,7 +34,8 @@ public interface IDataScorerTransform : IDataTransform, ITransformTemplate /// This parameter holds a snapshot of the role mapped training schema as /// it existed at the point when was trained, or null if it not /// available for some reason - public delegate void SignatureDataScorer(IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema); + [BestFriend] + internal delegate void SignatureDataScorer(IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema); public delegate void SignatureBindableMapper(IPredictor predictor); diff --git a/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs b/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs index 44356059e1..8364b17060 100644 --- a/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs +++ b/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs @@ -16,7 +16,8 @@ namespace Microsoft.ML.Runtime.Internal.Internallearn { - public abstract class FeatureNameCollection : IEnumerable + [BestFriend] + internal abstract class FeatureNameCollection : IEnumerable { private sealed class FeatureNameCollectionSchema : ISchema { diff --git a/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs b/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs index 312603b949..405dd31296 100644 --- a/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs +++ b/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs @@ -143,7 +143,8 @@ internal interface ICanGetSummaryAsIRow Row GetStatsIRowOrNull(RoleMappedSchema schema); } - public interface ICanGetSummaryAsIDataView + [BestFriend] + internal interface ICanGetSummaryAsIDataView { IDataView GetSummaryDataView(RoleMappedSchema schema); } diff --git a/src/Microsoft.ML.Data/Dirty/PredictorUtils.cs b/src/Microsoft.ML.Data/Dirty/PredictorUtils.cs index bcd5492f59..46e6a859f5 100644 --- a/src/Microsoft.ML.Data/Dirty/PredictorUtils.cs +++ b/src/Microsoft.ML.Data/Dirty/PredictorUtils.cs @@ -9,10 +9,11 @@ namespace Microsoft.ML.Runtime.Internal.Internallearn { - public static class PredictorUtils + [BestFriend] + internal static class PredictorUtils { /// - /// Save the model summary + /// Save the model summary. /// public static void SaveSummary(IChannel ch, IPredictor predictor, RoleMappedSchema schema, TextWriter writer) { diff --git a/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs b/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs index 49e5acdca5..12cee16b3b 100644 --- a/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs +++ b/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs @@ -43,7 +43,8 @@ public static CommonOutputs.SummaryOutput Summarize(IHostEnvironment env, Summar return output; } - public static IDataView GetSummaryAndStats(IHostEnvironment env, IPredictor predictor, RoleMappedSchema schema, out IDataView stats) + [BestFriend] + internal static IDataView GetSummaryAndStats(IHostEnvironment env, IPredictor predictor, RoleMappedSchema schema, out IDataView stats) { var calibrated = predictor as CalibratedPredictorBase; while (calibrated != null) diff --git a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs index a42d449aef..905bdaaa13 100644 --- a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs @@ -91,7 +91,7 @@ public AnomalyDetectionEvaluator(IHostEnvironment env, Arguments args) _aucCount = args.MaxAucExamples; } - protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) + private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); var t = score.Type; @@ -103,7 +103,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) throw Host.Except("Label column '{0}' has type '{1}' but must be R4 or a 2-value key", schema.Label.Name, t).MarkSensitive(MessageSensitivity.Schema); } - protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) + private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) { return new Aggregator(Host, _aucCount, _numTopResults, _k, _p, _streaming, schema.Name == null ? -1 : schema.Name.Index, stratName); } @@ -124,7 +124,7 @@ public override IEnumerable GetOverallMetricColumns() yield return new MetricColumn("ThreshAtNumPos", OverallMetrics.ThreshAtNumPos, MetricColumn.Objective.Info, canBeWeighted: false); } - protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, + private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, out Action, Aggregator> addAgg, out Func> consolidate) { var stratCol = new List(); @@ -498,7 +498,7 @@ private void FinishOtherMetrics() } } - public override void InitializeNextPass(Row row, RoleMappedSchema schema) + internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { Host.Assert(!_streaming && PassNum < 2 || PassNum < 1); Host.AssertValue(schema.Label); @@ -621,7 +621,7 @@ public AnomalyDetectionMamlEvaluator(IHostEnvironment env, Arguments args) _evaluator = new AnomalyDetectionEvaluator(Host, evalArgs); } - protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) + private protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) { IDataView top; if (!metrics.TryGetValue(AnomalyDetectionEvaluator.TopKResults, out top)) @@ -732,7 +732,7 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary GetPerInstanceColumnsToSave(RoleMappedSchema schema) + private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); Host.CheckValue(schema.Label, nameof(schema), "Data must contain a label column"); diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs index 2af7447f91..4468ba000d 100644 --- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs @@ -123,7 +123,7 @@ public BinaryClassifierEvaluator(IHostEnvironment env, Arguments args) _auPrcCount = args.NumAuPrcExamples; } - protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) + private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); var host = Host.SchemaSensitive(); @@ -136,7 +136,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) throw host.SchemaSensitive().Except("Label column '{0}' has type '{1}' but must be R4, R8, BL or a 2-value key", schema.Label.Name, t); } - protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema) + private protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema) { var prob = schema.GetColumns(MetadataUtils.Const.ScoreValueKind.Probability); var host = Host.SchemaSensitive(); @@ -155,7 +155,7 @@ protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema) } // Add also the probability column. - protected override Func GetActiveColsCore(RoleMappedSchema schema) + private protected override Func GetActiveColsCore(RoleMappedSchema schema) { var pred = base.GetActiveColsCore(schema); var prob = schema.GetColumns(MetadataUtils.Const.ScoreValueKind.Probability); @@ -163,7 +163,7 @@ protected override Func GetActiveColsCore(RoleMappedSchema schema) return i => Utils.Size(prob) > 0 && i == prob[0].Index || pred(i); } - protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) + private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) { var classNames = GetClassNames(schema); return new Aggregator(Host, classNames, schema.Weight != null, _aucCount, _auPrcCount, _threshold, _useRaw, _prCount, stratName); @@ -215,7 +215,7 @@ public override IEnumerable GetOverallMetricColumns() yield return new MetricColumn("AUPRC", AuPrc); } - protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, + private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, out Action, Aggregator> addAgg, out Func> consolidate) { var stratCol = new List(); @@ -609,7 +609,7 @@ public Aggregator(IHostEnvironment env, ReadOnlyMemory[] classNames, bool } } - public override void InitializeNextPass(Row row, RoleMappedSchema schema) + internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { Host.AssertValue(schema.Label); Host.Assert(PassNum < 1); @@ -1172,7 +1172,7 @@ public BinaryClassifierMamlEvaluator(IHostEnvironment env, Arguments args) _evaluator = new BinaryClassifierEvaluator(Host, evalArgs); } - protected override IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema) + private protected override IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema) { var cols = base.GetInputColumnRolesCore(schema); @@ -1187,7 +1187,7 @@ public BinaryClassifierMamlEvaluator(IHostEnvironment env, Arguments args) return cols; } - protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) + private protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) { ch.AssertValue(metrics); @@ -1240,12 +1240,12 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary[] metrics) + private protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics) { ch.AssertNonEmpty(metrics); @@ -1479,7 +1479,7 @@ private void SavePrPlots(List prList) return avgPoints; } #endif - protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) + private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column"); diff --git a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs index bd8fb347f7..fa60665447 100644 --- a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs @@ -94,7 +94,7 @@ public ClusteringMetrics Evaluate(IDataView data, string score, string label = n return result; } - protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) + private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) { ColumnType type; if (schema.Label != null && (type = schema.Label.Type) != NumberType.Float && type.KeyCount == 0) @@ -110,7 +110,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) throw Host.Except("Scores column '{0}' type must be a float vector of known size", score.Name); } - protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema) + private protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema) { if (_calculateDbi) { @@ -124,7 +124,7 @@ protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema) } } - protected override Func GetActiveColsCore(RoleMappedSchema schema) + private protected override Func GetActiveColsCore(RoleMappedSchema schema) { var pred = base.GetActiveColsCore(schema); // We also need the features column for dbi calculation. @@ -132,7 +132,7 @@ protected override Func GetActiveColsCore(RoleMappedSchema schema) return i => _calculateDbi && i == schema.Feature.Index || pred(i); } - protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) + private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) { Host.AssertValue(schema); Host.Assert(!_calculateDbi || (schema.Feature != null && schema.Feature.Type.IsKnownSizeVector)); @@ -156,7 +156,7 @@ public override IEnumerable GetOverallMetricColumns() yield return new MetricColumn("DBI", Dbi, MetricColumn.Objective.Minimize); } - protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, + private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, out Action, Aggregator> addAgg, out Func> consolidate) { var stratCol = new List(); @@ -484,7 +484,7 @@ private void ProcessRowSecondPass() WeightedCounters.UpdateSecondPass(in _features, _indicesArr); } - public override void InitializeNextPass(Row row, RoleMappedSchema schema) + internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { AssertValid(assertGetters: false); @@ -796,7 +796,7 @@ public ClusteringMamlEvaluator(IHostEnvironment env, Arguments args) _evaluator = new ClusteringEvaluator(Host, evalArgs); } - protected override IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema) + private protected override IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema) { foreach (var col in base.GetInputColumnRolesCore(schema)) { @@ -816,7 +816,7 @@ public ClusteringMamlEvaluator(IHostEnvironment env, Arguments args) } // Clustering evaluator adds three per-instance columns: "ClusterId", "Top clusters" and "Top cluster scores". - protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) + private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); @@ -830,7 +830,7 @@ protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSch yield return ClusteringPerInstanceEvaluator.SortedClusterScores; } - protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema) + private protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema) { // Wrap with a DropSlots transform to pick only the first _numTopClusters slots. if (perInst.Schema.TryGetColumnIndex(ClusteringPerInstanceEvaluator.SortedClusters, out int index)) diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs index ab1330c31b..287d882645 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs @@ -22,7 +22,8 @@ public abstract partial class EvaluatorBase : IEvaluator { protected readonly IHost Host; - protected EvaluatorBase(IHostEnvironment env, string registrationName) + [BestFriend] + private protected EvaluatorBase(IHostEnvironment env, string registrationName) { Contracts.CheckValue(env, nameof(env)); Host = env.Register(registrationName); @@ -44,7 +45,8 @@ Dictionary IEvaluator.Evaluate(RoleMappedData data) /// Checks the column types of the evaluator's input columns. The base class implementation checks only the type /// of the weight column, and all other columns should be checked by the deriving classes in . /// - protected void CheckColumnTypes(RoleMappedSchema schema) + [BestFriend] + private protected void CheckColumnTypes(RoleMappedSchema schema) { // Check the weight column type. if (schema.Weight != null) @@ -60,13 +62,15 @@ protected void CheckColumnTypes(RoleMappedSchema schema) /// Access the label column with the property, and the score column with the /// or methods. /// - protected abstract void CheckScoreAndLabelTypes(RoleMappedSchema schema); + [BestFriend] + private protected abstract void CheckScoreAndLabelTypes(RoleMappedSchema schema); /// /// Check the types of any other columns needed by the evaluator. Only override if the evaluator uses /// columns other than label, score and weight. /// - protected virtual void CheckCustomColumnTypesCore(RoleMappedSchema schema) + [BestFriend] + private protected virtual void CheckCustomColumnTypesCore(RoleMappedSchema schema) { } @@ -84,7 +88,8 @@ private Func GetActiveCols(RoleMappedSchema schema) /// and the stratification columns. /// Override if other input columns need to be activated. /// - protected virtual Func GetActiveColsCore(RoleMappedSchema schema) + [BestFriend] + private protected virtual Func GetActiveColsCore(RoleMappedSchema schema) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); var label = schema.Label == null ? -1 : schema.Label.Index; @@ -116,7 +121,8 @@ private AggregatorDictionaryBase[] GetAggregatorDictionaries(RoleMappedSchema sc return list.ToArray(); } - protected abstract TAgg GetAggregatorCore(RoleMappedSchema schema, string stratName); + [BestFriend] + private protected abstract TAgg GetAggregatorCore(RoleMappedSchema schema, string stratName); // This method does as many passes over the data as needed by the evaluator, and computes the metrics, outputting the // results in a dictionary from the metric kind (overal/per-fold/confusion matrix/PR-curves etc.), to a data view containing @@ -192,10 +198,12 @@ private Dictionary ProcessData(IDataView data, RoleMappedSche /// is called after has been called on all the aggregators, and it returns /// the dictionary of metric data views. /// - protected abstract void GetAggregatorConsolidationFuncs(TAgg aggregator, AggregatorDictionaryBase[] dictionaries, + [BestFriend] + private protected abstract void GetAggregatorConsolidationFuncs(TAgg aggregator, AggregatorDictionaryBase[] dictionaries, out Action, TAgg> addAgg, out Func> consolidate); - protected ValueGetter>> GetKeyValueGetter(AggregatorDictionaryBase[] dictionaries) + [BestFriend] + private protected ValueGetter>> GetKeyValueGetter(AggregatorDictionaryBase[] dictionaries) { if (Utils.Size(dictionaries) == 0) return null; @@ -236,7 +244,8 @@ public abstract class AggregatorBase protected int PassNum; - protected AggregatorBase(IHostEnvironment env, string stratName) + [BestFriend] + private protected AggregatorBase(IHostEnvironment env, string stratName) { Contracts.AssertValue(env); Host = env.Register("Aggregator"); @@ -256,7 +265,8 @@ public bool Start() /// /// This method should get the getters of the new IRow that are needed for the next pass. /// - public abstract void InitializeNextPass(Row row, RoleMappedSchema schema); + [BestFriend] + internal abstract void InitializeNextPass(Row row, RoleMappedSchema schema); /// /// Call the getters once, and process the input as necessary. @@ -327,15 +337,15 @@ protected virtual List GetWarningsCore() // When a new value is encountered, it uses a callback for creating a new aggregator. protected abstract class AggregatorDictionaryBase { - protected Row Row; - protected readonly Func CreateAgg; - protected readonly RoleMappedSchema Schema; + private protected Row Row; + private protected readonly Func CreateAgg; + private protected readonly RoleMappedSchema Schema; public string ColName { get; } public abstract int Count { get; } - protected AggregatorDictionaryBase(RoleMappedSchema schema, string stratCol, Func createAgg) + private protected AggregatorDictionaryBase(RoleMappedSchema schema, string stratCol, Func createAgg) { Contracts.AssertValue(schema); Contracts.AssertNonWhiteSpace(stratCol); @@ -351,7 +361,7 @@ protected AggregatorDictionaryBase(RoleMappedSchema schema, string stratCol, Fun /// public abstract void Reset(Row row); - public static AggregatorDictionaryBase Create(RoleMappedSchema schema, string stratCol, ColumnType stratType, + internal static AggregatorDictionaryBase Create(RoleMappedSchema schema, string stratCol, ColumnType stratType, Func createAgg) { Contracts.AssertNonWhiteSpace(stratCol); @@ -438,7 +448,8 @@ public override IEnumerable GetAll() public abstract class RowToRowEvaluatorBase : EvaluatorBase where TAgg : EvaluatorBase.AggregatorBase { - protected RowToRowEvaluatorBase(IHostEnvironment env, string registrationName) + [BestFriend] + private protected RowToRowEvaluatorBase(IHostEnvironment env, string registrationName) : base(env, registrationName) { } diff --git a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs index 34e8bda9f4..84221e8a6a 100644 --- a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs @@ -74,19 +74,26 @@ public abstract class ArgumentsBase : EvaluateInputBase public string[] StratColumn; } - public static RoleMappedSchema.ColumnRole Strat = "Strat"; - protected readonly IHost Host; + internal static RoleMappedSchema.ColumnRole Strat = "Strat"; + [BestFriend] + private protected readonly IHost Host; - protected readonly string ScoreColumnKind; - protected readonly string ScoreCol; - protected readonly string LabelCol; - protected readonly string WeightCol; - protected readonly string[] StratCols; + [BestFriend] + private protected readonly string ScoreColumnKind; + [BestFriend] + private protected readonly string ScoreCol; + [BestFriend] + private protected readonly string LabelCol; + [BestFriend] + private protected readonly string WeightCol; + [BestFriend] + private protected readonly string[] StratCols; [BestFriend] private protected abstract IEvaluator Evaluator { get; } - protected MamlEvaluatorBase(ArgumentsBase args, IHostEnvironment env, string scoreColumnKind, string registrationName) + [BestFriend] + private protected MamlEvaluatorBase(ArgumentsBase args, IHostEnvironment env, string scoreColumnKind, string registrationName) { Contracts.CheckValue(env, nameof(env)); Host = env.Register(registrationName); @@ -103,7 +110,8 @@ Dictionary IEvaluator.Evaluate(RoleMappedData data) return Evaluator.Evaluate(data); } - protected IEnumerable> GetInputColumnRoles(RoleMappedSchema schema, bool needStrat = false, bool needName = false) + [BestFriend] + private protected IEnumerable> GetInputColumnRoles(RoleMappedSchema schema, bool needStrat = false, bool needName = false) { Host.CheckValue(schema, nameof(schema)); @@ -122,7 +130,8 @@ Dictionary IEvaluator.Evaluate(RoleMappedData data) /// The base class ipmlementation gets the score column, the label column (if exists) and the weight column (if exists). /// Override if additional columns are needed. /// - protected virtual IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema) + [BestFriend] + private protected virtual IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema) { // Get the score column information. var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(ArgumentsBase.ScoreColumn), @@ -154,7 +163,8 @@ void IMamlEvaluator.PrintFoldResults(IChannel ch, Dictionary /// This method simply prints the overall metrics using EvaluateUtils.PrintConfusionMatrixAndPerFoldResults. /// Override if something else is needed. /// - protected virtual void PrintFoldResultsCore(IChannel ch, Dictionary metrics) + [BestFriend] + private protected virtual void PrintFoldResultsCore(IChannel ch, Dictionary metrics) { ch.AssertValue(ch); ch.AssertValue(metrics); @@ -177,12 +187,14 @@ IDataView IMamlEvaluator.GetOverallResults(params IDataView[] metrics) return GetOverallResultsCore(overall); } - protected virtual IDataView CombineOverallMetricsCore(IDataView[] metrics) + [BestFriend] + private protected virtual IDataView CombineOverallMetricsCore(IDataView[] metrics) { return EvaluateUtils.ConcatenateOverallMetrics(Host, metrics); } - protected virtual IDataView GetOverallResultsCore(IDataView overall) + [BestFriend] + private protected virtual IDataView GetOverallResultsCore(IDataView overall) { return overall; } @@ -198,7 +210,8 @@ void IMamlEvaluator.PrintAdditionalMetrics(IChannel ch, params Dictionary - protected virtual void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics) + [BestFriend] + private protected virtual void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics) { } @@ -259,7 +272,8 @@ private IDataView WrapPerInstance(RoleMappedData perInst) /// It should be overridden only if additional processing is needed, such as dropping slots in the "top k scores" column /// in the multi-class case. /// - protected virtual IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema) + [BestFriend] + private protected virtual IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema) { return perInst; } @@ -276,6 +290,7 @@ IDataView IMamlEvaluator.GetPerInstanceDataViewToSave(RoleMappedData perInstance /// the columns generated by the corresponding , or any of the input columns used by /// it. The Name and Weight columns should not be included, since the base class includes them automatically. /// - protected abstract IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema); + [BestFriend] + private protected abstract IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema); } } diff --git a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs index e5426f2b04..de7be6ec84 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs @@ -58,7 +58,7 @@ public enum Metrics LogLossReduction, } - public const string LoadName = "MultiClassClassifierEvaluator"; + internal const string LoadName = "MultiClassClassifierEvaluator"; private readonly int? _outputTopKAcc; private readonly bool _names; @@ -72,7 +72,7 @@ public MultiClassClassifierEvaluator(IHostEnvironment env, Arguments args) _names = args.Names; } - protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) + private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); var t = score.Type; @@ -84,7 +84,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) throw Host.Except("Label column '{0}' has type {1} but must be a float or a known-cardinality key", schema.Label.Name, t); } - protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) + private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); Host.Assert(score.Type.VectorSize > 0); @@ -137,7 +137,7 @@ public override IEnumerable GetOverallMetricColumns() yield return new MetricColumn("LogLossReduction", LogLossReduction); } - protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, + private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, out Action, Aggregator> addAgg, out Func> consolidate) { var stratCol = new List(); @@ -387,7 +387,7 @@ public Aggregator(IHostEnvironment env, ReadOnlyMemory[] classNames, int s ClassNames = classNames; } - public override void InitializeNextPass(Row row, RoleMappedSchema schema) + internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { Host.Assert(PassNum < 1); Host.AssertValue(schema.Label); @@ -871,7 +871,7 @@ public MultiClassMamlEvaluator(IHostEnvironment env, Arguments args) _evaluator = new MultiClassClassifierEvaluator(Host, evalArgs); } - protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) + private protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) { Host.AssertValue(metrics); @@ -901,7 +901,7 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary(); @@ -940,7 +940,7 @@ protected override IDataView CombineOverallMetricsCore(IDataView[] metrics) return base.CombineOverallMetricsCore(views); } - protected override IDataView GetOverallResultsCore(IDataView overall) + private protected override IDataView GetOverallResultsCore(IDataView overall) { // Change the name of the Top-k-accuracy column. if (_outputTopKAcc != null) @@ -978,7 +978,7 @@ public override IEnumerable GetOverallMetricColumns() yield return new MetricColumn("LogLossReduction", MultiClassClassifierEvaluator.LogLossReduction); } - protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) + private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column"); @@ -994,7 +994,7 @@ protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSch } // Multi-class evaluator adds four per-instance columns: "Assigned", "Top scores", "Top classes" and "Log-loss". - protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema) + private protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema) { // If the label column is a key without key values, convert it to I8, just for saving the per-instance // text file, since if there are different key counts the columns cannot be appended. diff --git a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs index 86232f647f..2f1e8769b8 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs @@ -56,7 +56,7 @@ private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchem return new MultiOutputRegressionPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, schema.Label.Name); } - protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) + private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); var t = score.Type; @@ -68,7 +68,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) throw Host.Except("Label column '{0}' has type '{1}' but must be a known-size vector of R4 or R8", schema.Label.Name, t); } - protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) + private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); Host.Assert(score.Type.VectorSize > 0); @@ -92,7 +92,7 @@ public override IEnumerable GetOverallMetricColumns() groupName: "label", nameFormat: string.Format("{0} (Label_{{0}}", PerLabelLoss)); } - protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, + private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, out Action, Aggregator> addAgg, out Func> consolidate) { var stratCol = new List(); @@ -299,7 +299,7 @@ public Aggregator(IHostEnvironment env, IRegressionLoss lossFunction, int size, WeightedCounters = Weighted ? new Counters(lossFunction, _size) : null; } - public override void InitializeNextPass(Row row, RoleMappedSchema schema) + internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { Contracts.Assert(PassNum < 1); Contracts.AssertValue(schema.Label); @@ -636,7 +636,7 @@ public MultiOutputRegressionMamlEvaluator(IHostEnvironment env, Arguments args) _evaluator = new MultiOutputRegressionEvaluator(Host, evalArgs); } - protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) + private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column"); @@ -658,7 +658,7 @@ protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSch } // The multi-output regression evaluator prints only the per-label metrics for each fold. - protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) + private protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) { IDataView fold; if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out fold)) diff --git a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs index d91dc943ac..e7f342d1a1 100644 --- a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs @@ -53,7 +53,7 @@ private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchem return new QuantileRegressionPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, schema.Label.Name, scoreSize, quantiles); } - protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) + private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); var t = score.Type; @@ -68,7 +68,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) throw Host.Except("Label column '{0}' has type '{1}' but must be R4", schema.Label.Name, t); } - protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) + private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) { var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); var t = scoreInfo.Type; @@ -487,7 +487,7 @@ public QuantileRegressionMamlEvaluator(IHostEnvironment env, Arguments args) _evaluator = new QuantileRegressionEvaluator(Host, evalArgs); } - protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) + private protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics) { ch.AssertValue(metrics); @@ -505,7 +505,7 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary GetOverallMetricColumns() yield return new MetricColumn("RSquared", QuantileRegressionEvaluator.RSquared); } - protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) + private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column"); diff --git a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs index 5523057462..6872890699 100644 --- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs @@ -84,7 +84,7 @@ public RankerEvaluator(IHostEnvironment env, Arguments args) _labelGains = labelGains.ToArray(); } - protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) + private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) { var t = schema.Label.Type; if (t != NumberType.Float && !t.IsKey) @@ -100,7 +100,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) } } - protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema) + private protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema) { var t = schema.Group.Type; if (!t.IsKey) @@ -112,13 +112,13 @@ protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema) } // Add also the group column. - protected override Func GetActiveColsCore(RoleMappedSchema schema) + private protected override Func GetActiveColsCore(RoleMappedSchema schema) { var pred = base.GetActiveColsCore(schema); return i => i == schema.Group.Index || pred(i); } - protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) + private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) { return new Aggregator(Host, _labelGains, _truncationLevel, _groupSummary, schema.Weight != null, stratName); } @@ -147,7 +147,7 @@ public override IEnumerable GetOverallMetricColumns() groupName: "at", nameFormat: string.Format("{0} @{{0}}", MaxDcg)); } - protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, + private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries, out Action, Aggregator> addAgg, out Func> consolidate) { var stratCol = new List(); @@ -440,7 +440,7 @@ public Aggregator(IHostEnvironment env, Double[] labelGains, int truncationLevel GroupId = new List>(); } - public override void InitializeNextPass(Row row, RoleMappedSchema schema) + internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { Contracts.Assert(PassNum < 1); Contracts.AssertValue(schema.Label); @@ -882,14 +882,14 @@ public RankerMamlEvaluator(IHostEnvironment env, Arguments args) _groupIdCol = args.GroupIdColumn; } - protected override IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema) + private protected override IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema) { var cols = base.GetInputColumnRolesCore(schema); var groupIdCol = EvaluateUtils.GetColName(_groupIdCol, schema.Group, DefaultColumnNames.GroupId); return cols.Prepend(RoleMappedSchema.ColumnRole.Group.Bind(groupIdCol)); } - protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics) + private protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics) { ch.AssertNonEmpty(metrics); @@ -929,7 +929,7 @@ private bool TryGetGroupSummaryMetrics(Dictionary[] metrics, return true; } - protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) + private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); Host.CheckValue(schema.Label, nameof(schema), "Data must contain a label column"); diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs index 48d94e3159..c4004a494e 100644 --- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs @@ -52,7 +52,7 @@ public RegressionEvaluator(IHostEnvironment env, Arguments args) { } - protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) + private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); var t = score.Type; @@ -64,7 +64,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) throw Host.Except("Label column '{0}' has type '{1}' but must be R4", schema.Label.Name, t); } - protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) + private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) { return new Aggregator(Host, LossFunction, schema.Weight != null, stratName); } @@ -350,7 +350,7 @@ public RegressionMamlEvaluator(IHostEnvironment env, Arguments args) _evaluator = new RegressionEvaluator(Host, evalArgs); } - protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) + private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column"); diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs index 8343e9765e..c10cb87903 100644 --- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs +++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs @@ -38,12 +38,13 @@ protected RegressionLossEvaluatorBase(ArgumentsBase args, IHostEnvironment env, public abstract class RegressionEvaluatorBase : RegressionLossEvaluatorBase where TAgg : RegressionEvaluatorBase.RegressionAggregatorBase { - protected RegressionEvaluatorBase(ArgumentsBase args, IHostEnvironment env, string registrationName) + [BestFriend] + private protected RegressionEvaluatorBase(ArgumentsBase args, IHostEnvironment env, string registrationName) : base(args, env, registrationName) { } - protected override void GetAggregatorConsolidationFuncs(TAgg aggregator, AggregatorDictionaryBase[] dictionaries, + private protected override void GetAggregatorConsolidationFuncs(TAgg aggregator, AggregatorDictionaryBase[] dictionaries, out Action, TAgg> addAgg, out Func> consolidate) { var stratCol = new List(); @@ -184,7 +185,8 @@ public void Update(ref TScore score, float label, float weight, ref TMetrics los public abstract CountersBase UnweightedCounters { get; } public abstract CountersBase WeightedCounters { get; } - protected RegressionAggregatorBase(IHostEnvironment env, IRegressionLoss lossFunction, bool weighted, string stratName) + [BestFriend] + private protected RegressionAggregatorBase(IHostEnvironment env, IRegressionLoss lossFunction, bool weighted, string stratName) : base(env, stratName) { Host.AssertValue(lossFunction); @@ -192,7 +194,7 @@ protected RegressionAggregatorBase(IHostEnvironment env, IRegressionLoss lossFun Weighted = weighted; } - public override void InitializeNextPass(Row row, RoleMappedSchema schema) + internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { Contracts.Assert(PassNum < 1); Contracts.AssertValue(schema.Label); diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs index ae02024084..ee800ea4a0 100644 --- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs @@ -125,7 +125,8 @@ private static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBound return MultiClassClassifierScorer.LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)mapper, type as VectorType, getter, MetadataUtils.Kinds.TrainingLabelValues, CanWrap); } - public BinaryClassifierScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) + [BestFriend] + internal BinaryClassifierScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) : base(args, env, data, WrapIfNeeded(env, mapper, trainSchema), trainSchema, RegistrationName, MetadataUtils.Const.ScoreColumnKind.BinaryClassification, Contracts.CheckRef(args, nameof(args)).ThresholdColumn, OutputTypeMatches, GetPredColType) { @@ -168,7 +169,7 @@ public static BinaryClassifierScorer Create(IHostEnvironment env, ModelLoadConte return h.Apply("Loading Model", ch => new BinaryClassifierScorer(h, ctx, input)); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -269,7 +270,7 @@ private void GetPredictedLabelCoreAsKey(Float score, ref uint value) value = (uint)(score > _threshold ? 2 : score <= _threshold ? 1 : 0); } - protected override JToken PredictedLabelPfa(string[] mapperOutputs) + private protected override JToken PredictedLabelPfa(string[] mapperOutputs) { Contracts.CheckParam(Utils.Size(mapperOutputs) >= 1, nameof(mapperOutputs)); diff --git a/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs b/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs index ce1e447de7..1b5aa01c05 100644 --- a/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs @@ -44,7 +44,8 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "ClusteringScore"; - public ClusteringScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) + [BestFriend] + internal ClusteringScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) : base(args, env, data, mapper, trainSchema, RegistrationName, MetadataUtils.Const.ScoreColumnKind.Clustering, MetadataUtils.Const.ScoreValueKind.Score, OutputTypeMatches, GetPredColType) { @@ -73,7 +74,7 @@ public static ClusteringScorer Create(IHostEnvironment env, ModelLoadContext ctx return h.Apply("Loading Model", ch => new ClusteringScorer(h, ctx, input)); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -127,7 +128,7 @@ protected override Delegate GetPredictedLabelGetter(Row output, out Delegate sco return predFn; } - protected override JToken PredictedLabelPfa(string[] mapperOutputs) + private protected override JToken PredictedLabelPfa(string[] mapperOutputs) { Contracts.Assert(Utils.Size(mapperOutputs) == 1); return PfaUtils.Call("a.argmax", mapperOutputs[0]); diff --git a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs index e8090fc614..ceea843335 100644 --- a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs +++ b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs @@ -64,7 +64,8 @@ public sealed class Arguments : ScorerArgumentsBase // REVIEW: the scorer currently ignores the 'suffix' argument from the base class. It should respect it. } - public static IDataScorerTransform Create(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) + [BestFriend] + internal static IDataScorerTransform Create(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(data, nameof(data)); @@ -82,7 +83,8 @@ public static IDataScorerTransform Create(IHostEnvironment env, Arguments args, return scoredPipe; } - public static ISchemaBindableMapper Create(IHostEnvironment env, Arguments args, IPredictor predictor) + [BestFriend] + internal static ISchemaBindableMapper Create(IHostEnvironment env, Arguments args, IPredictor predictor) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(args, nameof(args)); @@ -116,7 +118,10 @@ public static IDataScorerTransform Create(IHostEnvironment env, Arguments args, return Create(env, args, data, boundMapper, null); } - public static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadContext ctx) + /// + /// Create method corresponding to . + /// + private static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadContext ctx) { return new BindableMapper(env, ctx); } diff --git a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs index f301792603..8cf6eaaf7d 100644 --- a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs @@ -138,7 +138,7 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "GenericScore"; private readonly Bindings _bindings; - protected override BindingsBase GetBindings() => _bindings; + private protected override BindingsBase GetBindings() => _bindings; public override Schema OutputSchema { get; } @@ -149,7 +149,8 @@ private static VersionInfo GetVersionInfo() /// /// The entry point for creating a . /// - public GenericScorer(IHostEnvironment env, ScorerArgumentsBase args, IDataView data, + [BestFriend] + internal GenericScorer(IHostEnvironment env, ScorerArgumentsBase args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) : base(env, data, RegistrationName, Contracts.CheckRef(mapper, nameof(mapper)).Bindable) { @@ -199,7 +200,7 @@ public static GenericScorer Create(IHostEnvironment env, ModelLoadContext ctx, I return h.Apply("Loading Model", ch => new GenericScorer(h, ctx, input)); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs index a8ce0092c0..7de5b0cb5a 100644 --- a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs @@ -79,7 +79,6 @@ public sealed class LabelNameBindableMapper : ISchemaBindableMapper, ICanSaveMod public VectorType Type => _type; bool ICanSavePfa.CanSavePfa => (_bindable as ICanSavePfa)?.CanSavePfa == true; bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (_bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; - public ISchemaBindableMapper InnerBindable => _bindable; private static VersionInfo GetVersionInfo() { @@ -138,8 +137,6 @@ private LabelNameBindableMapper(IHost host, ModelLoadContext ctx) ctx.LoadNonEmptyString() : MetadataUtils.Kinds.SlotNames; } - public ISchemaBindableMapper Clone(ISchemaBindableMapper inner) => new LabelNameBindableMapper(_host, inner, _type, _getter, _metadataKind, _canWrap); - private Delegate DecodeInit(object value) { _host.CheckDecode(value is VBuffer); @@ -148,7 +145,10 @@ private Delegate DecodeInit(object value) return buffGetter; } - public static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadContext ctx) + /// + /// Method corresponding to . + /// + private static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); var h = env.Register(LoaderSignature); @@ -211,7 +211,7 @@ bool IBindableCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, s return ((IBindableCanSaveOnnx)_bindable).SaveAsOnnx(ctx, schema, outputNames); } - public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) + ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema) { var innerBound = _bindable.Bind(env, schema); if (_canWrap != null && !_canWrap(innerBound, _type)) @@ -220,7 +220,7 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) return Utils.MarshalInvoke(CreateBound, _type.ItemType.RawType, env, (ISchemaBoundRowMapper)innerBound, _type, _getter, _metadataKind, _canWrap); } - public static ISchemaBoundMapper CreateBound(IHostEnvironment env, ISchemaBoundRowMapper mapper, VectorType type, Delegate getter, + internal static ISchemaBoundMapper CreateBound(IHostEnvironment env, ISchemaBoundRowMapper mapper, VectorType type, Delegate getter, string metadataKind, Func canWrap) { Contracts.AssertValue(env); @@ -393,7 +393,7 @@ private static ISchemaBoundMapper WrapIfNeeded(IHostEnvironment env, ISchemaBoun /// from the model of a bindable mapper) /// Whether we can call with /// this mapper and expect it to succeed - public static bool CanWrap(ISchemaBoundMapper mapper, ColumnType labelNameType) + private static bool CanWrap(ISchemaBoundMapper mapper, ColumnType labelNameType) { Contracts.AssertValue(mapper); Contracts.AssertValue(labelNameType); @@ -414,7 +414,7 @@ public static bool CanWrap(ISchemaBoundMapper mapper, ColumnType labelNameType) return labelNameType.IsVector && labelNameType.VectorSize == scoreType.VectorSize; } - public static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) + private static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) { Contracts.AssertValue(env); env.AssertValue(mapper); @@ -436,7 +436,8 @@ public static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBoundM return LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)mapper, type as VectorType, getter, MetadataUtils.Kinds.SlotNames, CanWrap); } - public MultiClassClassifierScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) + [BestFriend] + internal MultiClassClassifierScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) : base(args, env, data, WrapIfNeeded(env, mapper, trainSchema), trainSchema, RegistrationName, MetadataUtils.Const.ScoreColumnKind.MultiClassClassification, MetadataUtils.Const.ScoreValueKind.Score, OutputTypeMatches, GetPredColType) { @@ -454,7 +455,10 @@ private MultiClassClassifierScorer(IHost host, ModelLoadContext ctx, IDataView i // } - public static MultiClassClassifierScorer Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + /// + /// Corresponds to . + /// + private static MultiClassClassifierScorer Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) { Contracts.CheckValue(env, nameof(env)); var h = env.Register(RegistrationName); @@ -464,7 +468,7 @@ public static MultiClassClassifierScorer Create(IHostEnvironment env, ModelLoadC return h.Apply("Loading Model", ch => new MultiClassClassifierScorer(h, ctx, input)); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -518,7 +522,7 @@ protected override Delegate GetPredictedLabelGetter(Row output, out Delegate sco return predFn; } - protected override JToken PredictedLabelPfa(string[] mapperOutputs) + private protected override JToken PredictedLabelPfa(string[] mapperOutputs) { Contracts.Assert(Utils.Size(mapperOutputs) == 1); return PfaUtils.Call("a.argmax", mapperOutputs[0]); diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs index e6f044c48f..0a869d250b 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs @@ -30,7 +30,8 @@ public abstract class ThresholdArgumentsBase : ScorerArgumentsBase public string ThresholdColumn = MetadataUtils.Const.ScoreValueKind.Score; } - protected sealed class BindingsImpl : BindingsBase + [BestFriend] + private protected sealed class BindingsImpl : BindingsBase { // Column index of the score column in Mapper's schema. public readonly int ScoreColumnIndex; @@ -272,15 +273,18 @@ public override Func GetActiveMapperColumns(bool[] active) } } - protected readonly BindingsImpl Bindings; - protected override BindingsBase GetBindings() => Bindings; + [BestFriend] + private protected readonly BindingsImpl Bindings; + [BestFriend] + private protected sealed override BindingsBase GetBindings() => Bindings; public override Schema OutputSchema { get; } bool ICanSavePfa.CanSavePfa => (Bindable as ICanSavePfa)?.CanSavePfa == true; bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (Bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; - protected PredictedLabelScorerBase(ScorerArgumentsBase args, IHostEnvironment env, IDataView data, + [BestFriend] + private protected PredictedLabelScorerBase(ScorerArgumentsBase args, IHostEnvironment env, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema, string registrationName, string scoreColKind, string scoreColName, Func outputTypeMatches, Func getPredColType) : base(env, data, registrationName, Contracts.CheckRef(mapper, nameof(mapper)).Bindable) @@ -314,7 +318,8 @@ protected PredictedLabelScorerBase(IHostEnvironment env, PredictedLabelScorerBas OutputSchema = Schema.Create(Bindings); } - protected PredictedLabelScorerBase(IHost host, ModelLoadContext ctx, IDataView input, + [BestFriend] + private protected PredictedLabelScorerBase(IHost host, ModelLoadContext ctx, IDataView input, Func outputTypeMatches, Func getPredColType) : base(host, ctx, input) { @@ -327,7 +332,7 @@ protected PredictedLabelScorerBase(IHost host, ModelLoadContext ctx, IDataView i OutputSchema = Schema.Create(Bindings); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Host.AssertValue(ctx); Bindings.Save(ctx); @@ -360,7 +365,8 @@ void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx) ctx.DeclareVar(derivedName, predictedLabelExpression); } - protected abstract JToken PredictedLabelPfa(string[] mapperOutputs); + [BestFriend] + private protected abstract JToken PredictedLabelPfa(string[] mapperOutputs); void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx) => SaveAsOnnxCore(ctx); diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index 25d31e13bc..08bf96af22 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -47,7 +47,8 @@ public abstract class PredictionTransformerBase : IPredictionTr protected const string DirModel = "Model"; protected const string DirTransSchema = "TrainSchema"; protected readonly IHost Host; - protected ISchemaBindableMapper BindableMapper; + [BestFriend] + private protected ISchemaBindableMapper BindableMapper; protected Schema TrainSchema; public bool IsRowToRowMapper => true; diff --git a/src/Microsoft.ML.Data/Scorers/QuantileRegressionScorer.cs b/src/Microsoft.ML.Data/Scorers/QuantileRegressionScorer.cs index 032df06b99..2dec3a0e4a 100644 --- a/src/Microsoft.ML.Data/Scorers/QuantileRegressionScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/QuantileRegressionScorer.cs @@ -18,7 +18,7 @@ namespace Microsoft.ML.Runtime.Data { - public static class QuantileRegressionScorerTransform + internal static class QuantileRegressionScorerTransform { public sealed class Arguments : ScorerArgumentsBase { @@ -26,12 +26,18 @@ public sealed class Arguments : ScorerArgumentsBase public string Quantiles = "0,0.25,0.5,0.75,1"; } - public static IDataScorerTransform Create(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) + /// + /// Constructor corresponding to . + /// + private static IDataScorerTransform Create(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) { return new GenericScorer(env, args, data, mapper, trainSchema); } - public static ISchemaBindableMapper Create(IHostEnvironment env, Arguments args, IPredictor predictor) + /// + /// Constructor corresponding to . + /// + private static ISchemaBindableMapper Create(IHostEnvironment env, Arguments args, IPredictor predictor) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(args, nameof(args)); diff --git a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs index 54375cfc25..697a23309e 100644 --- a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs @@ -19,7 +19,8 @@ namespace Microsoft.ML.Runtime.Data /// public abstract class RowToRowScorerBase : RowToRowMapperTransformBase, IDataScorerTransform { - public abstract class BindingsBase : ScorerBindingsBase + [BestFriend] + private protected abstract class BindingsBase : ScorerBindingsBase { public readonly ISchemaBoundRowMapper RowMapper; @@ -30,16 +31,19 @@ protected BindingsBase(Schema schema, ISchemaBoundRowMapper mapper, string suffi } } - protected readonly ISchemaBindableMapper Bindable; + [BestFriend] + private protected readonly ISchemaBindableMapper Bindable; - protected RowToRowScorerBase(IHostEnvironment env, IDataView input, string registrationName, ISchemaBindableMapper bindable) + [BestFriend] + private protected RowToRowScorerBase(IHostEnvironment env, IDataView input, string registrationName, ISchemaBindableMapper bindable) : base(env, registrationName, input) { Contracts.AssertValue(bindable); Bindable = bindable; } - protected RowToRowScorerBase(IHost host, ModelLoadContext ctx, IDataView input) + [BestFriend] + private protected RowToRowScorerBase(IHost host, ModelLoadContext ctx, IDataView input) : base(host, input) { ctx.LoadModel(host, out Bindable, "SchemaBindableMapper"); @@ -56,7 +60,8 @@ public sealed override void Save(ModelSaveContext ctx) /// /// The main save method handles saving the _bindable. This should do everything else. /// - protected abstract void SaveCore(ModelSaveContext ctx); + [BestFriend] + private protected abstract void SaveCore(ModelSaveContext ctx); /// /// For the ITransformTemplate implementation. @@ -66,7 +71,8 @@ public sealed override void Save(ModelSaveContext ctx) /// /// Derived classes provide the specific bindings object. /// - protected abstract BindingsBase GetBindings(); + [BestFriend] + private protected abstract BindingsBase GetBindings(); /// /// Produces the set of active columns for the scorer (as a bool[] of length bindings.ColumnCount), @@ -296,11 +302,12 @@ public abstract class ScorerArgumentsBase } /// - /// Base bindings for a scorer based on an ISchemaBoundMapper. This assumes that input schema columns + /// Base bindings for a scorer based on an . This assumes that input schema columns /// are echoed, followed by zero or more derived columns, followed by the mapper generated columns. /// The names of the derived columns and mapper generated columns have an optional suffix appended. /// - public abstract class ScorerBindingsBase : ColumnBindingsBase + [BestFriend] + internal abstract class ScorerBindingsBase : ColumnBindingsBase { /// /// The schema bound mapper. diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs index 5da07c1c95..e6142e3b50 100644 --- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs +++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs @@ -34,7 +34,7 @@ namespace Microsoft.ML.Runtime.Data /// /// This is a base class for wrapping s in an . /// - public abstract class SchemaBindablePredictorWrapperBase : ISchemaBindableMapper, ICanSaveModel, ICanSaveSummary, + internal abstract class SchemaBindablePredictorWrapperBase : ISchemaBindableMapper, ICanSaveModel, ICanSaveSummary, IBindableCanSavePfa, IBindableCanSaveOnnx { // The ctor guarantees that Predictor is non-null. It also ensures that either @@ -115,7 +115,7 @@ bool IBindableCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, s [BestFriend] private protected virtual bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) => false; - public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) + ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema) { Contracts.CheckValue(env, nameof(env)); @@ -142,7 +142,8 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) } } - protected abstract ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema); + [BestFriend] + private protected abstract ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema); protected virtual Delegate GetPredictionGetter(Row input, int colSrc) { @@ -240,7 +241,7 @@ public Row GetRow(Row input, Func predicate) /// This class is a wrapper for all s except for quantile regression predictors, /// and calibrated binary classification predictors. /// - public sealed class SchemaBindablePredictorWrapper : SchemaBindablePredictorWrapperBase + internal sealed class SchemaBindablePredictorWrapper : SchemaBindablePredictorWrapperBase { public const string LoaderSignature = "SchemaBindableWrapper"; private static VersionInfo GetVersionInfo() @@ -316,7 +317,7 @@ private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(schema.Feature.Name)); } - protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) + private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) { var outputSchema = Schema.Create(new ScoreMapperSchema(ScoreType, _scoreColumnKind)); return new SingleValueRowMapper(schema, this, outputSchema); @@ -352,7 +353,7 @@ private static string GetScoreColumnKind(IPredictor predictor) /// This is an wrapper for calibrated binary classification predictors. /// They need a separate wrapper because they return two values instead of one: the raw score and the probability. /// - public sealed class SchemaBindableBinaryPredictorWrapper : SchemaBindablePredictorWrapperBase + internal sealed class SchemaBindableBinaryPredictorWrapper : SchemaBindablePredictorWrapperBase { public const string LoaderSignature = "BinarySchemaBindable"; private static VersionInfo GetVersionInfo() @@ -450,7 +451,7 @@ private void CheckValid(out IValueMapperDist distMapper) "Invalid probability type for the IValueMapperDist"); } - protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) + private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) { if (Predictor.PredictionKind != PredictionKind.BinaryClassification) ch.Warning("Scoring predictor of kind '{0}' as '{1}'.", Predictor.PredictionKind, PredictionKind.BinaryClassification); @@ -577,9 +578,10 @@ public Row GetRow(Row input, Func predicate) /// /// This is an wrapper for quantile regression predictors. They need a separate - /// wrapper because they need the quantiles to create the ISchemaBound. + /// wrapper because they need the quantiles to create the . /// - public sealed class SchemaBindableQuantileRegressionPredictor : SchemaBindablePredictorWrapperBase + [BestFriend] + internal sealed class SchemaBindableQuantileRegressionPredictor : SchemaBindablePredictorWrapperBase { public const string LoaderSignature = "QuantileSchemaBindable"; private static VersionInfo GetVersionInfo() @@ -650,7 +652,7 @@ public static SchemaBindableQuantileRegressionPredictor Create(IHostEnvironment return new SchemaBindableQuantileRegressionPredictor(env, ctx); } - protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) + private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) { return new SingleValueRowMapper(schema, this, Schema.Create(new SchemaImpl(ScoreType, _quantiles))); } diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs index 15957c225c..28bca6c58f 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs @@ -316,7 +316,7 @@ public static void LoadModel(ModelLoadContext ctx, int cv, out bool useLog, out /// It tracks min, max, number of non-sparse values (vCount) and number of ProcessValue() calls (trainCount). /// NaNs are ignored when updating min and max. /// - public sealed class MinMaxSngAggregator : IColumnAggregator> + internal sealed class MinMaxSngAggregator : IColumnAggregator> { private readonly TFloat[] _min; private readonly TFloat[] _max; diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs b/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs index 0f1ea0ef62..b26111b03c 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs @@ -40,7 +40,8 @@ internal interface IColumnFunctionBuilder /// /// Interface to define an aggregate function over values /// - public interface IColumnAggregator + [BestFriend] + internal interface IColumnAggregator { /// /// Updates the aggregate function with a value @@ -68,7 +69,7 @@ internal interface IColumnFunction : ICanSaveModel NormalizingTransformer.NormalizerModelParametersBase GetNormalizerModelParams(); } - public static class NormalizeUtils + internal static class NormalizeUtils { /// /// Returns whether the feature column in the schema is indicated to be normalized. If the features column is not diff --git a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs index e8ec5f4e53..dbc4e4d685 100644 --- a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs @@ -20,7 +20,8 @@ namespace Microsoft.ML.Transforms { - public static class ScoringTransformer + [BestFriend] + internal static class ScoringTransformer { public sealed class Arguments : TransformInputBase { diff --git a/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs b/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs index e7a61c42d6..0b962f435e 100644 --- a/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs +++ b/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs @@ -20,7 +20,8 @@ namespace Microsoft.ML.Runtime.Model /// /// This class provides utilities for loading components from the model file generated by MAML commands. /// - public static class ModelFileUtils + [BestFriend] + internal static class ModelFileUtils { public const string DirPredictor = "Predictor"; public const string DirDataLoaderModel = "DataLoaderModel"; diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs index d6717d7224..97ddf9ef43 100644 --- a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs @@ -218,7 +218,7 @@ protected override void SaveCore(ModelSaveContext ctx) ctx.SaveModel(Combiner, "Combiner"); } - public override ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) + private protected override ISchemaBoundMapper BindCore(IHostEnvironment env, RoleMappedSchema schema) { return new Bound(this, schema); } @@ -558,7 +558,9 @@ public static SchemaBindablePipelineEnsembleBase Create(IHostEnvironment env, Mo } } - public abstract ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema); + ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema) => BindCore(env, schema); + + private protected abstract ISchemaBoundMapper BindCore(IHostEnvironment env, RoleMappedSchema schema); void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema) { diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/RegressionTree.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/RegressionTree.cs index 834a4188f1..7cd371ba91 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsemble/RegressionTree.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsemble/RegressionTree.cs @@ -1119,7 +1119,7 @@ public void RemapFeatures(int[] oldToNewFeatures) /// A map of feature index (in the features array) /// to the ID as it will be written in the file. This instance should be /// used for all - public void ToTreeEnsembleFormat(StringBuilder sbEvaluator, StringBuilder sbInput, FeaturesToContentMap featureContents, + internal void ToTreeEnsembleFormat(StringBuilder sbEvaluator, StringBuilder sbInput, FeaturesToContentMap featureContents, ref int evaluatorCounter, Dictionary featureToId) { Contracts.AssertValue(sbEvaluator); @@ -1259,7 +1259,7 @@ private void ToTreeEnsembleFormatForCategoricalSplit(StringBuilder sbEvaluator, } // prints the tree out as a string (in old Bing format used by LambdaMART and AdIndex) - public string ToOldIni(FeatureNameCollection featureNames) + internal string ToOldIni(FeatureNameCollection featureNames) { // print the root node StringBuilder output = new StringBuilder(); diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs index 08ddb00a32..b1a21f9541 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs @@ -107,7 +107,7 @@ public void RemapFeatures(int[] oldToNewFeatures) /// /// returns the ensemble in the production TreeEnsemble format /// - public string ToTreeEnsembleIni(FeaturesToContentMap fmap, + internal string ToTreeEnsembleIni(FeaturesToContentMap fmap, string trainingParams, bool appendFeatureGain, bool includeZeroGainFeatures = true) { StringBuilder sbEvaluator = new StringBuilder(); @@ -306,7 +306,7 @@ public void GetOutputs(Dataset dataset, double[] outputs, int prefix) Parallel.Invoke(new ParallelOptions { MaxDegreeOfParallelism = BlockingThreadPool.NumThreads }, actions); } - public string ToGainSummary(FeaturesToContentMap fmap, Dictionary featureToID, int prefix, bool includeZeroGainFeatures, bool normalize, int startingCommentNumber) + internal string ToGainSummary(FeaturesToContentMap fmap, Dictionary featureToID, int prefix, bool includeZeroGainFeatures, bool normalize, int startingCommentNumber) { if (_trees.Count == 0) return string.Empty; @@ -397,7 +397,7 @@ public FeatureToGainMap(IList trees, bool normalize) /// A class that given either a /// provides a mechanism for getting the corresponding input INI content for the features. /// - public sealed class FeaturesToContentMap + internal sealed class FeaturesToContentMap { private readonly VBuffer> _content; private readonly VBuffer> _names; diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index b6ac50bb70..27951364fc 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -525,7 +525,7 @@ private void GetPathSlotNames(int col, ref VBuffer> dst) dst = editor.Commit(); } - public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) + ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema) { Contracts.AssertValue(env); env.AssertValue(schema); @@ -654,7 +654,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV vm.InputType.VectorSize, data.Schema.Feature.Type.VectorSize); } - var bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor); + ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor); var bound = bindable.Bind(env, data.Schema); xf = new GenericScorer(env, scorerArgs, input, bound, data.Schema); } @@ -718,7 +718,7 @@ public static IDataTransform CreateForEntryPoint(IHostEnvironment env, Arguments vm.InputType.VectorSize, data.Schema.Feature.Type.VectorSize); } - var bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor); + ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor); var bound = bindable.Bind(env, data.Schema); return new GenericScorer(env, scorerArgs, data.Data, bound, data.Schema); } diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs index fee4a25d07..3cbb610e97 100644 --- a/src/Microsoft.ML.PCA/PcaTrainer.cs +++ b/src/Microsoft.ML.PCA/PcaTrainer.cs @@ -530,7 +530,7 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) } } - public IDataView GetSummaryDataView(RoleMappedSchema schema) + IDataView ICanGetSummaryAsIDataView.GetSummaryDataView(RoleMappedSchema schema) { var bldr = new ArrayDataViewBuilder(Host); diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs b/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs index f33bc405b8..fb78fdcd33 100644 --- a/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs +++ b/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs @@ -261,7 +261,7 @@ private float Score(int columnIndex, int rowIndex) /// Create a row mapper based on regression scorer. Because matrix factorization predictor maps a tuple of a row ID (u) and a column ID (v) /// to the expected numerical value at the u-th row and the v-th column in the considered matrix, it is essentially a regressor. /// - public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) + ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema) { Contracts.AssertValue(env); env.AssertValue(schema); diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs index 636f49fb60..a417935210 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs @@ -173,7 +173,7 @@ internal float CalculateResponse(ValueGetter>[] getters, VBuffer< return modelResponse; } - public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) + ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema) => new FieldAwareFactorizationMachineScalarRowMapper(env, schema, Schema.Create(new BinaryClassifierSchema()), this); internal void CopyLinearWeightsTo(float[] linearWeights) diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 6085504bb6..0be38f3547 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -970,7 +970,7 @@ private void SaveLabelNames(ModelSaveContext ctx, BinaryWriter writer) } } - public IDataView GetSummaryDataView(RoleMappedSchema schema) + IDataView ICanGetSummaryAsIDataView.GetSummaryDataView(RoleMappedSchema schema) { var bldr = new ArrayDataViewBuilder(Host); diff --git a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs index 0a75f3b948..47c993dcfa 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.CpuMath; +using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.Model; @@ -324,7 +325,7 @@ private List GetUnorderedCoefficientStatistics(LinearBina /// /// Gets the coefficient statistics as an object. /// - public CoefficientStatistics[] GetCoefficientStatistics(LinearBinaryModelParameters parent, RoleMappedSchema schema, int paramCountCap) + internal CoefficientStatistics[] GetCoefficientStatistics(LinearBinaryModelParameters parent, RoleMappedSchema schema, int paramCountCap) { Contracts.AssertValue(_env); _env.CheckValue(parent, nameof(parent)); @@ -345,7 +346,7 @@ public CoefficientStatistics[] GetCoefficientStatistics(LinearBinaryModelParamet return order.Prepend(new[] { new CoefficientStatistics("(Bias)", bias, stdError, zScore, pValue) }).ToArray(); } - public void SaveText(TextWriter writer, LinearBinaryModelParameters parent, RoleMappedSchema schema, int paramCountCap) + internal void SaveText(TextWriter writer, LinearBinaryModelParameters parent, RoleMappedSchema schema, int paramCountCap) { Contracts.AssertValue(_env); _env.CheckValue(writer, nameof(writer)); @@ -383,7 +384,10 @@ public void SaveText(TextWriter writer, LinearBinaryModelParameters parent, Role writer.WriteLine("Significance codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1"); } - public void SaveSummaryInKeyValuePairs(LinearBinaryModelParameters parent, + /// + /// Support method for linear models and . + /// + internal void SaveSummaryInKeyValuePairs(LinearBinaryModelParameters parent, RoleMappedSchema schema, int paramCountCap, List> resultCollection) { Contracts.AssertValue(_env);