diff --git a/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs b/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs
new file mode 100644
index 0000000000..29de619946
--- /dev/null
+++ b/src/Microsoft.ML.Core/Data/IRowToRowMapper.cs
@@ -0,0 +1,54 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Data;
+using System;
+
+namespace Microsoft.ML.Runtime.Data
+{
+ ///
+ /// This interface maps an input to an output . Typically, the output contains
+ /// both the input columns and new columns added by the implementing class, although some implementations may
+ /// return a subset of the input columns.
+ /// This interface is similar to , except it does not have any input role mappings,
+ /// so to rebind, the same input column names must be used.
+ /// Implementations of this interface are typically created over defined input .
+ ///
+ public interface IRowToRowMapper
+ {
+ ///
+ /// Mappers are defined as accepting inputs with this very specific schema.
+ ///
+ Schema InputSchema { get; }
+
+ ///
+ /// Gets an instance of which describes the columns' names and types in the output generated by this mapper.
+ ///
+ Schema OutputSchema { get; }
+
+ ///
+ /// Given a predicate specifying which columns are needed, return a predicate indicating which input columns are
+ /// needed. The domain of the function is defined over the indices of the columns of
+ /// for .
+ ///
+ Func GetDependencies(Func predicate);
+
+ ///
+ /// Get an with the indicated active columns, based on the input .
+ /// The active columns are those for which returns true. Getting values on inactive
+ /// columns of the returned row will throw. Null predicates are disallowed.
+ ///
+ /// The of should be the same object as
+ /// . Implementors of this method should throw if that is not the case. Conversely,
+ /// the returned value must have the same schema as .
+ ///
+ /// This method creates a live connection between the input and the output . In particular, when the getters of the output are invoked, they invoke the
+ /// getters of the input row and base the output values on the current values of the input .
+ /// The output values are re-computed when requested through the getters. Also, the returned
+ /// will dispose when it is disposed.
+ ///
+ Row GetRow(Row input, Func active);
+ }
+}
diff --git a/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs b/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs
index caa7b15330..ff005d157d 100644
--- a/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs
+++ b/src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs
@@ -3,7 +3,6 @@
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Data;
-using System;
using System.Collections.Generic;
namespace Microsoft.ML.Runtime.Data
@@ -21,7 +20,8 @@ namespace Microsoft.ML.Runtime.Data
/// for the output schema of the . In case the interface is implemented,
/// the SimpleRow class can be used in the method.
///
- public interface ISchemaBindableMapper
+ [BestFriend]
+ internal interface ISchemaBindableMapper
{
ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema);
}
@@ -30,7 +30,8 @@ public interface ISchemaBindableMapper
/// This interface is used to map a schema from input columns to output columns. The should keep track
/// of the input columns that are needed for the mapping.
///
- public interface ISchemaBoundMapper
+ [BestFriend]
+ internal interface ISchemaBoundMapper
{
///
/// The that was passed to the in the binding process.
@@ -56,7 +57,8 @@ public interface ISchemaBoundMapper
///
/// This interface combines with .
///
- public interface ISchemaBoundRowMapper : ISchemaBoundMapper, IRowToRowMapper
+ [BestFriend]
+ internal interface ISchemaBoundRowMapper : ISchemaBoundMapper, IRowToRowMapper
{
///
/// There are two schemas from and .
@@ -64,49 +66,4 @@ public interface ISchemaBoundRowMapper : ISchemaBoundMapper, IRowToRowMapper
///
new Schema OutputSchema { get; }
}
-
- ///
- /// This interface maps an input to an output . Typically, the output contains
- /// both the input columns and new columns added by the implementing class, although some implementations may
- /// return a subset of the input columns.
- /// This interface is similar to , except it does not have any input role mappings,
- /// so to rebind, the same input column names must be used.
- /// Implementations of this interface are typically created over defined input .
- ///
- public interface IRowToRowMapper
- {
- ///
- /// Mappers are defined as accepting inputs with this very specific schema.
- ///
- Schema InputSchema { get; }
-
- ///
- /// Gets an instance of which describes the columns' names and types in the output generated by this mapper.
- ///
- Schema OutputSchema { get; }
-
- ///
- /// Given a predicate specifying which columns are needed, return a predicate indicating which input columns are
- /// needed. The domain of the function is defined over the indices of the columns of
- /// for .
- ///
- Func GetDependencies(Func predicate);
-
- ///
- /// Get an with the indicated active columns, based on the input .
- /// The active columns are those for which returns true. Getting values on inactive
- /// columns of the returned row will throw. Null predicates are disallowed.
- ///
- /// The of should be the same object as
- /// . Implementors of this method should throw if that is not the case. Conversely,
- /// the returned value must have the same schema as .
- ///
- /// This method creates a live connection between the input and the output . In particular, when the getters of the output are invoked, they invoke the
- /// getters of the input row and base the output values on the current values of the input .
- /// The output values are re-computed when requested through the getters. Also, the returned
- /// will dispose when it is disposed.
- ///
- Row GetRow(Row input, Func active);
- }
}
diff --git a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs
index 94ac81af9c..9a3f02cc24 100644
--- a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs
+++ b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs
@@ -100,7 +100,8 @@ public static ColumnInfo CreateFromIndex(Schema schema, int index)
///
///
///
- public sealed class RoleMappedSchema
+ [BestFriend]
+ internal sealed class RoleMappedSchema
{
private const string FeatureString = "Feature";
private const string LabelString = "Label";
diff --git a/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs b/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs
index 633f871dd0..c537cbd3ab 100644
--- a/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs
@@ -114,7 +114,8 @@ private Stream CreateStrm(IFileHandle file)
}
}
- public static class SavePredictorUtils
+ [BestFriend]
+ internal static class SavePredictorUtils
{
public static void SavePredictor(IHostEnvironment env, Stream modelStream, Stream binaryModelStream = null, Stream summaryModelStream = null,
Stream textModelStream = null, Stream iniModelStream = null, Stream codeModelStream = null)
diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
index 89c7572b63..fe289fa4f7 100644
--- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
@@ -34,7 +34,8 @@ public interface IDataScorerTransform : IDataTransform, ITransformTemplate
/// This parameter holds a snapshot of the role mapped training schema as
/// it existed at the point when was trained, or null if it not
/// available for some reason
- public delegate void SignatureDataScorer(IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema);
+ [BestFriend]
+ internal delegate void SignatureDataScorer(IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema);
public delegate void SignatureBindableMapper(IPredictor predictor);
diff --git a/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs b/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs
index 44356059e1..8364b17060 100644
--- a/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs
+++ b/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs
@@ -16,7 +16,8 @@
namespace Microsoft.ML.Runtime.Internal.Internallearn
{
- public abstract class FeatureNameCollection : IEnumerable
+ [BestFriend]
+ internal abstract class FeatureNameCollection : IEnumerable
{
private sealed class FeatureNameCollectionSchema : ISchema
{
diff --git a/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs b/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs
index 312603b949..405dd31296 100644
--- a/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs
+++ b/src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs
@@ -143,7 +143,8 @@ internal interface ICanGetSummaryAsIRow
Row GetStatsIRowOrNull(RoleMappedSchema schema);
}
- public interface ICanGetSummaryAsIDataView
+ [BestFriend]
+ internal interface ICanGetSummaryAsIDataView
{
IDataView GetSummaryDataView(RoleMappedSchema schema);
}
diff --git a/src/Microsoft.ML.Data/Dirty/PredictorUtils.cs b/src/Microsoft.ML.Data/Dirty/PredictorUtils.cs
index bcd5492f59..46e6a859f5 100644
--- a/src/Microsoft.ML.Data/Dirty/PredictorUtils.cs
+++ b/src/Microsoft.ML.Data/Dirty/PredictorUtils.cs
@@ -9,10 +9,11 @@
namespace Microsoft.ML.Runtime.Internal.Internallearn
{
- public static class PredictorUtils
+ [BestFriend]
+ internal static class PredictorUtils
{
///
- /// Save the model summary
+ /// Save the model summary.
///
public static void SaveSummary(IChannel ch, IPredictor predictor, RoleMappedSchema schema, TextWriter writer)
{
diff --git a/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs b/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs
index 49e5acdca5..12cee16b3b 100644
--- a/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs
@@ -43,7 +43,8 @@ public static CommonOutputs.SummaryOutput Summarize(IHostEnvironment env, Summar
return output;
}
- public static IDataView GetSummaryAndStats(IHostEnvironment env, IPredictor predictor, RoleMappedSchema schema, out IDataView stats)
+ [BestFriend]
+ internal static IDataView GetSummaryAndStats(IHostEnvironment env, IPredictor predictor, RoleMappedSchema schema, out IDataView stats)
{
var calibrated = predictor as CalibratedPredictorBase;
while (calibrated != null)
diff --git a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs
index a42d449aef..905bdaaa13 100644
--- a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs
@@ -91,7 +91,7 @@ public AnomalyDetectionEvaluator(IHostEnvironment env, Arguments args)
_aucCount = args.MaxAucExamples;
}
- protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
+ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
{
var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
var t = score.Type;
@@ -103,7 +103,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
throw Host.Except("Label column '{0}' has type '{1}' but must be R4 or a 2-value key", schema.Label.Name, t).MarkSensitive(MessageSensitivity.Schema);
}
- protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
+ private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
{
return new Aggregator(Host, _aucCount, _numTopResults, _k, _p, _streaming, schema.Name == null ? -1 : schema.Name.Index, stratName);
}
@@ -124,7 +124,7 @@ public override IEnumerable GetOverallMetricColumns()
yield return new MetricColumn("ThreshAtNumPos", OverallMetrics.ThreshAtNumPos, MetricColumn.Objective.Info, canBeWeighted: false);
}
- protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
+ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
out Action, Aggregator> addAgg, out Func> consolidate)
{
var stratCol = new List();
@@ -498,7 +498,7 @@ private void FinishOtherMetrics()
}
}
- public override void InitializeNextPass(Row row, RoleMappedSchema schema)
+ internal override void InitializeNextPass(Row row, RoleMappedSchema schema)
{
Host.Assert(!_streaming && PassNum < 2 || PassNum < 1);
Host.AssertValue(schema.Label);
@@ -621,7 +621,7 @@ public AnomalyDetectionMamlEvaluator(IHostEnvironment env, Arguments args)
_evaluator = new AnomalyDetectionEvaluator(Host, evalArgs);
}
- protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics)
+ private protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics)
{
IDataView top;
if (!metrics.TryGetValue(AnomalyDetectionEvaluator.TopKResults, out top))
@@ -732,7 +732,7 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary GetPerInstanceColumnsToSave(RoleMappedSchema schema)
+ private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema)
{
Host.CheckValue(schema, nameof(schema));
Host.CheckValue(schema.Label, nameof(schema), "Data must contain a label column");
diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
index 2af7447f91..4468ba000d 100644
--- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
@@ -123,7 +123,7 @@ public BinaryClassifierEvaluator(IHostEnvironment env, Arguments args)
_auPrcCount = args.NumAuPrcExamples;
}
- protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
+ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
{
var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
var host = Host.SchemaSensitive();
@@ -136,7 +136,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
throw host.SchemaSensitive().Except("Label column '{0}' has type '{1}' but must be R4, R8, BL or a 2-value key", schema.Label.Name, t);
}
- protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema)
+ private protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema)
{
var prob = schema.GetColumns(MetadataUtils.Const.ScoreValueKind.Probability);
var host = Host.SchemaSensitive();
@@ -155,7 +155,7 @@ protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema)
}
// Add also the probability column.
- protected override Func GetActiveColsCore(RoleMappedSchema schema)
+ private protected override Func GetActiveColsCore(RoleMappedSchema schema)
{
var pred = base.GetActiveColsCore(schema);
var prob = schema.GetColumns(MetadataUtils.Const.ScoreValueKind.Probability);
@@ -163,7 +163,7 @@ protected override Func GetActiveColsCore(RoleMappedSchema schema)
return i => Utils.Size(prob) > 0 && i == prob[0].Index || pred(i);
}
- protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
+ private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
{
var classNames = GetClassNames(schema);
return new Aggregator(Host, classNames, schema.Weight != null, _aucCount, _auPrcCount, _threshold, _useRaw, _prCount, stratName);
@@ -215,7 +215,7 @@ public override IEnumerable GetOverallMetricColumns()
yield return new MetricColumn("AUPRC", AuPrc);
}
- protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
+ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
out Action, Aggregator> addAgg, out Func> consolidate)
{
var stratCol = new List();
@@ -609,7 +609,7 @@ public Aggregator(IHostEnvironment env, ReadOnlyMemory[] classNames, bool
}
}
- public override void InitializeNextPass(Row row, RoleMappedSchema schema)
+ internal override void InitializeNextPass(Row row, RoleMappedSchema schema)
{
Host.AssertValue(schema.Label);
Host.Assert(PassNum < 1);
@@ -1172,7 +1172,7 @@ public BinaryClassifierMamlEvaluator(IHostEnvironment env, Arguments args)
_evaluator = new BinaryClassifierEvaluator(Host, evalArgs);
}
- protected override IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema)
+ private protected override IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema)
{
var cols = base.GetInputColumnRolesCore(schema);
@@ -1187,7 +1187,7 @@ public BinaryClassifierMamlEvaluator(IHostEnvironment env, Arguments args)
return cols;
}
- protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics)
+ private protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics)
{
ch.AssertValue(metrics);
@@ -1240,12 +1240,12 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary[] metrics)
+ private protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics)
{
ch.AssertNonEmpty(metrics);
@@ -1479,7 +1479,7 @@ private void SavePrPlots(List prList)
return avgPoints;
}
#endif
- protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema)
+ private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema)
{
Host.CheckValue(schema, nameof(schema));
Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column");
diff --git a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs
index bd8fb347f7..fa60665447 100644
--- a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs
@@ -94,7 +94,7 @@ public ClusteringMetrics Evaluate(IDataView data, string score, string label = n
return result;
}
- protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
+ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
{
ColumnType type;
if (schema.Label != null && (type = schema.Label.Type) != NumberType.Float && type.KeyCount == 0)
@@ -110,7 +110,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
throw Host.Except("Scores column '{0}' type must be a float vector of known size", score.Name);
}
- protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema)
+ private protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema)
{
if (_calculateDbi)
{
@@ -124,7 +124,7 @@ protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema)
}
}
- protected override Func GetActiveColsCore(RoleMappedSchema schema)
+ private protected override Func GetActiveColsCore(RoleMappedSchema schema)
{
var pred = base.GetActiveColsCore(schema);
// We also need the features column for dbi calculation.
@@ -132,7 +132,7 @@ protected override Func GetActiveColsCore(RoleMappedSchema schema)
return i => _calculateDbi && i == schema.Feature.Index || pred(i);
}
- protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
+ private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
{
Host.AssertValue(schema);
Host.Assert(!_calculateDbi || (schema.Feature != null && schema.Feature.Type.IsKnownSizeVector));
@@ -156,7 +156,7 @@ public override IEnumerable GetOverallMetricColumns()
yield return new MetricColumn("DBI", Dbi, MetricColumn.Objective.Minimize);
}
- protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
+ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
out Action, Aggregator> addAgg, out Func> consolidate)
{
var stratCol = new List();
@@ -484,7 +484,7 @@ private void ProcessRowSecondPass()
WeightedCounters.UpdateSecondPass(in _features, _indicesArr);
}
- public override void InitializeNextPass(Row row, RoleMappedSchema schema)
+ internal override void InitializeNextPass(Row row, RoleMappedSchema schema)
{
AssertValid(assertGetters: false);
@@ -796,7 +796,7 @@ public ClusteringMamlEvaluator(IHostEnvironment env, Arguments args)
_evaluator = new ClusteringEvaluator(Host, evalArgs);
}
- protected override IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema)
+ private protected override IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema)
{
foreach (var col in base.GetInputColumnRolesCore(schema))
{
@@ -816,7 +816,7 @@ public ClusteringMamlEvaluator(IHostEnvironment env, Arguments args)
}
// Clustering evaluator adds three per-instance columns: "ClusterId", "Top clusters" and "Top cluster scores".
- protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema)
+ private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema)
{
Host.CheckValue(schema, nameof(schema));
@@ -830,7 +830,7 @@ protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSch
yield return ClusteringPerInstanceEvaluator.SortedClusterScores;
}
- protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema)
+ private protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema)
{
// Wrap with a DropSlots transform to pick only the first _numTopClusters slots.
if (perInst.Schema.TryGetColumnIndex(ClusteringPerInstanceEvaluator.SortedClusters, out int index))
diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs
index ab1330c31b..287d882645 100644
--- a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs
+++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs
@@ -22,7 +22,8 @@ public abstract partial class EvaluatorBase : IEvaluator
{
protected readonly IHost Host;
- protected EvaluatorBase(IHostEnvironment env, string registrationName)
+ [BestFriend]
+ private protected EvaluatorBase(IHostEnvironment env, string registrationName)
{
Contracts.CheckValue(env, nameof(env));
Host = env.Register(registrationName);
@@ -44,7 +45,8 @@ Dictionary IEvaluator.Evaluate(RoleMappedData data)
/// Checks the column types of the evaluator's input columns. The base class implementation checks only the type
/// of the weight column, and all other columns should be checked by the deriving classes in .
///
- protected void CheckColumnTypes(RoleMappedSchema schema)
+ [BestFriend]
+ private protected void CheckColumnTypes(RoleMappedSchema schema)
{
// Check the weight column type.
if (schema.Weight != null)
@@ -60,13 +62,15 @@ protected void CheckColumnTypes(RoleMappedSchema schema)
/// Access the label column with the property, and the score column with the
/// or methods.
///
- protected abstract void CheckScoreAndLabelTypes(RoleMappedSchema schema);
+ [BestFriend]
+ private protected abstract void CheckScoreAndLabelTypes(RoleMappedSchema schema);
///
/// Check the types of any other columns needed by the evaluator. Only override if the evaluator uses
/// columns other than label, score and weight.
///
- protected virtual void CheckCustomColumnTypesCore(RoleMappedSchema schema)
+ [BestFriend]
+ private protected virtual void CheckCustomColumnTypesCore(RoleMappedSchema schema)
{
}
@@ -84,7 +88,8 @@ private Func GetActiveCols(RoleMappedSchema schema)
/// and the stratification columns.
/// Override if other input columns need to be activated.
///
- protected virtual Func GetActiveColsCore(RoleMappedSchema schema)
+ [BestFriend]
+ private protected virtual Func GetActiveColsCore(RoleMappedSchema schema)
{
var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
var label = schema.Label == null ? -1 : schema.Label.Index;
@@ -116,7 +121,8 @@ private AggregatorDictionaryBase[] GetAggregatorDictionaries(RoleMappedSchema sc
return list.ToArray();
}
- protected abstract TAgg GetAggregatorCore(RoleMappedSchema schema, string stratName);
+ [BestFriend]
+ private protected abstract TAgg GetAggregatorCore(RoleMappedSchema schema, string stratName);
// This method does as many passes over the data as needed by the evaluator, and computes the metrics, outputting the
// results in a dictionary from the metric kind (overal/per-fold/confusion matrix/PR-curves etc.), to a data view containing
@@ -192,10 +198,12 @@ private Dictionary ProcessData(IDataView data, RoleMappedSche
/// is called after has been called on all the aggregators, and it returns
/// the dictionary of metric data views.
///
- protected abstract void GetAggregatorConsolidationFuncs(TAgg aggregator, AggregatorDictionaryBase[] dictionaries,
+ [BestFriend]
+ private protected abstract void GetAggregatorConsolidationFuncs(TAgg aggregator, AggregatorDictionaryBase[] dictionaries,
out Action, TAgg> addAgg, out Func> consolidate);
- protected ValueGetter>> GetKeyValueGetter(AggregatorDictionaryBase[] dictionaries)
+ [BestFriend]
+ private protected ValueGetter>> GetKeyValueGetter(AggregatorDictionaryBase[] dictionaries)
{
if (Utils.Size(dictionaries) == 0)
return null;
@@ -236,7 +244,8 @@ public abstract class AggregatorBase
protected int PassNum;
- protected AggregatorBase(IHostEnvironment env, string stratName)
+ [BestFriend]
+ private protected AggregatorBase(IHostEnvironment env, string stratName)
{
Contracts.AssertValue(env);
Host = env.Register("Aggregator");
@@ -256,7 +265,8 @@ public bool Start()
///
/// This method should get the getters of the new IRow that are needed for the next pass.
///
- public abstract void InitializeNextPass(Row row, RoleMappedSchema schema);
+ [BestFriend]
+ internal abstract void InitializeNextPass(Row row, RoleMappedSchema schema);
///
/// Call the getters once, and process the input as necessary.
@@ -327,15 +337,15 @@ protected virtual List GetWarningsCore()
// When a new value is encountered, it uses a callback for creating a new aggregator.
protected abstract class AggregatorDictionaryBase
{
- protected Row Row;
- protected readonly Func CreateAgg;
- protected readonly RoleMappedSchema Schema;
+ private protected Row Row;
+ private protected readonly Func CreateAgg;
+ private protected readonly RoleMappedSchema Schema;
public string ColName { get; }
public abstract int Count { get; }
- protected AggregatorDictionaryBase(RoleMappedSchema schema, string stratCol, Func createAgg)
+ private protected AggregatorDictionaryBase(RoleMappedSchema schema, string stratCol, Func createAgg)
{
Contracts.AssertValue(schema);
Contracts.AssertNonWhiteSpace(stratCol);
@@ -351,7 +361,7 @@ protected AggregatorDictionaryBase(RoleMappedSchema schema, string stratCol, Fun
///
public abstract void Reset(Row row);
- public static AggregatorDictionaryBase Create(RoleMappedSchema schema, string stratCol, ColumnType stratType,
+ internal static AggregatorDictionaryBase Create(RoleMappedSchema schema, string stratCol, ColumnType stratType,
Func createAgg)
{
Contracts.AssertNonWhiteSpace(stratCol);
@@ -438,7 +448,8 @@ public override IEnumerable GetAll()
public abstract class RowToRowEvaluatorBase : EvaluatorBase
where TAgg : EvaluatorBase.AggregatorBase
{
- protected RowToRowEvaluatorBase(IHostEnvironment env, string registrationName)
+ [BestFriend]
+ private protected RowToRowEvaluatorBase(IHostEnvironment env, string registrationName)
: base(env, registrationName)
{
}
diff --git a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs
index 34e8bda9f4..84221e8a6a 100644
--- a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs
@@ -74,19 +74,26 @@ public abstract class ArgumentsBase : EvaluateInputBase
public string[] StratColumn;
}
- public static RoleMappedSchema.ColumnRole Strat = "Strat";
- protected readonly IHost Host;
+ internal static RoleMappedSchema.ColumnRole Strat = "Strat";
+ [BestFriend]
+ private protected readonly IHost Host;
- protected readonly string ScoreColumnKind;
- protected readonly string ScoreCol;
- protected readonly string LabelCol;
- protected readonly string WeightCol;
- protected readonly string[] StratCols;
+ [BestFriend]
+ private protected readonly string ScoreColumnKind;
+ [BestFriend]
+ private protected readonly string ScoreCol;
+ [BestFriend]
+ private protected readonly string LabelCol;
+ [BestFriend]
+ private protected readonly string WeightCol;
+ [BestFriend]
+ private protected readonly string[] StratCols;
[BestFriend]
private protected abstract IEvaluator Evaluator { get; }
- protected MamlEvaluatorBase(ArgumentsBase args, IHostEnvironment env, string scoreColumnKind, string registrationName)
+ [BestFriend]
+ private protected MamlEvaluatorBase(ArgumentsBase args, IHostEnvironment env, string scoreColumnKind, string registrationName)
{
Contracts.CheckValue(env, nameof(env));
Host = env.Register(registrationName);
@@ -103,7 +110,8 @@ Dictionary IEvaluator.Evaluate(RoleMappedData data)
return Evaluator.Evaluate(data);
}
- protected IEnumerable> GetInputColumnRoles(RoleMappedSchema schema, bool needStrat = false, bool needName = false)
+ [BestFriend]
+ private protected IEnumerable> GetInputColumnRoles(RoleMappedSchema schema, bool needStrat = false, bool needName = false)
{
Host.CheckValue(schema, nameof(schema));
@@ -122,7 +130,8 @@ Dictionary IEvaluator.Evaluate(RoleMappedData data)
/// The base class ipmlementation gets the score column, the label column (if exists) and the weight column (if exists).
/// Override if additional columns are needed.
///
- protected virtual IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema)
+ [BestFriend]
+ private protected virtual IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema)
{
// Get the score column information.
var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(ArgumentsBase.ScoreColumn),
@@ -154,7 +163,8 @@ void IMamlEvaluator.PrintFoldResults(IChannel ch, Dictionary
/// This method simply prints the overall metrics using EvaluateUtils.PrintConfusionMatrixAndPerFoldResults.
/// Override if something else is needed.
///
- protected virtual void PrintFoldResultsCore(IChannel ch, Dictionary metrics)
+ [BestFriend]
+ private protected virtual void PrintFoldResultsCore(IChannel ch, Dictionary metrics)
{
ch.AssertValue(ch);
ch.AssertValue(metrics);
@@ -177,12 +187,14 @@ IDataView IMamlEvaluator.GetOverallResults(params IDataView[] metrics)
return GetOverallResultsCore(overall);
}
- protected virtual IDataView CombineOverallMetricsCore(IDataView[] metrics)
+ [BestFriend]
+ private protected virtual IDataView CombineOverallMetricsCore(IDataView[] metrics)
{
return EvaluateUtils.ConcatenateOverallMetrics(Host, metrics);
}
- protected virtual IDataView GetOverallResultsCore(IDataView overall)
+ [BestFriend]
+ private protected virtual IDataView GetOverallResultsCore(IDataView overall)
{
return overall;
}
@@ -198,7 +210,8 @@ void IMamlEvaluator.PrintAdditionalMetrics(IChannel ch, params Dictionary
- protected virtual void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics)
+ [BestFriend]
+ private protected virtual void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics)
{
}
@@ -259,7 +272,8 @@ private IDataView WrapPerInstance(RoleMappedData perInst)
/// It should be overridden only if additional processing is needed, such as dropping slots in the "top k scores" column
/// in the multi-class case.
///
- protected virtual IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema)
+ [BestFriend]
+ private protected virtual IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema)
{
return perInst;
}
@@ -276,6 +290,7 @@ IDataView IMamlEvaluator.GetPerInstanceDataViewToSave(RoleMappedData perInstance
/// the columns generated by the corresponding , or any of the input columns used by
/// it. The Name and Weight columns should not be included, since the base class includes them automatically.
///
- protected abstract IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema);
+ [BestFriend]
+ private protected abstract IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema);
}
}
diff --git a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs
index e5426f2b04..de7be6ec84 100644
--- a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs
@@ -58,7 +58,7 @@ public enum Metrics
LogLossReduction,
}
- public const string LoadName = "MultiClassClassifierEvaluator";
+ internal const string LoadName = "MultiClassClassifierEvaluator";
private readonly int? _outputTopKAcc;
private readonly bool _names;
@@ -72,7 +72,7 @@ public MultiClassClassifierEvaluator(IHostEnvironment env, Arguments args)
_names = args.Names;
}
- protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
+ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
{
var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
var t = score.Type;
@@ -84,7 +84,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
throw Host.Except("Label column '{0}' has type {1} but must be a float or a known-cardinality key", schema.Label.Name, t);
}
- protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
+ private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
{
var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
Host.Assert(score.Type.VectorSize > 0);
@@ -137,7 +137,7 @@ public override IEnumerable GetOverallMetricColumns()
yield return new MetricColumn("LogLossReduction", LogLossReduction);
}
- protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
+ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
out Action, Aggregator> addAgg, out Func> consolidate)
{
var stratCol = new List();
@@ -387,7 +387,7 @@ public Aggregator(IHostEnvironment env, ReadOnlyMemory[] classNames, int s
ClassNames = classNames;
}
- public override void InitializeNextPass(Row row, RoleMappedSchema schema)
+ internal override void InitializeNextPass(Row row, RoleMappedSchema schema)
{
Host.Assert(PassNum < 1);
Host.AssertValue(schema.Label);
@@ -871,7 +871,7 @@ public MultiClassMamlEvaluator(IHostEnvironment env, Arguments args)
_evaluator = new MultiClassClassifierEvaluator(Host, evalArgs);
}
- protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics)
+ private protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics)
{
Host.AssertValue(metrics);
@@ -901,7 +901,7 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary();
@@ -940,7 +940,7 @@ protected override IDataView CombineOverallMetricsCore(IDataView[] metrics)
return base.CombineOverallMetricsCore(views);
}
- protected override IDataView GetOverallResultsCore(IDataView overall)
+ private protected override IDataView GetOverallResultsCore(IDataView overall)
{
// Change the name of the Top-k-accuracy column.
if (_outputTopKAcc != null)
@@ -978,7 +978,7 @@ public override IEnumerable GetOverallMetricColumns()
yield return new MetricColumn("LogLossReduction", MultiClassClassifierEvaluator.LogLossReduction);
}
- protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema)
+ private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema)
{
Host.CheckValue(schema, nameof(schema));
Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column");
@@ -994,7 +994,7 @@ protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSch
}
// Multi-class evaluator adds four per-instance columns: "Assigned", "Top scores", "Top classes" and "Log-loss".
- protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema)
+ private protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema)
{
// If the label column is a key without key values, convert it to I8, just for saving the per-instance
// text file, since if there are different key counts the columns cannot be appended.
diff --git a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs
index 86232f647f..2f1e8769b8 100644
--- a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs
@@ -56,7 +56,7 @@ private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchem
return new MultiOutputRegressionPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, schema.Label.Name);
}
- protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
+ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
{
var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
var t = score.Type;
@@ -68,7 +68,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
throw Host.Except("Label column '{0}' has type '{1}' but must be a known-size vector of R4 or R8", schema.Label.Name, t);
}
- protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
+ private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
{
var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
Host.Assert(score.Type.VectorSize > 0);
@@ -92,7 +92,7 @@ public override IEnumerable GetOverallMetricColumns()
groupName: "label", nameFormat: string.Format("{0} (Label_{{0}}", PerLabelLoss));
}
- protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
+ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
out Action, Aggregator> addAgg, out Func> consolidate)
{
var stratCol = new List();
@@ -299,7 +299,7 @@ public Aggregator(IHostEnvironment env, IRegressionLoss lossFunction, int size,
WeightedCounters = Weighted ? new Counters(lossFunction, _size) : null;
}
- public override void InitializeNextPass(Row row, RoleMappedSchema schema)
+ internal override void InitializeNextPass(Row row, RoleMappedSchema schema)
{
Contracts.Assert(PassNum < 1);
Contracts.AssertValue(schema.Label);
@@ -636,7 +636,7 @@ public MultiOutputRegressionMamlEvaluator(IHostEnvironment env, Arguments args)
_evaluator = new MultiOutputRegressionEvaluator(Host, evalArgs);
}
- protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema)
+ private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema)
{
Host.CheckValue(schema, nameof(schema));
Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column");
@@ -658,7 +658,7 @@ protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSch
}
// The multi-output regression evaluator prints only the per-label metrics for each fold.
- protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics)
+ private protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics)
{
IDataView fold;
if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out fold))
diff --git a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs
index d91dc943ac..e7f342d1a1 100644
--- a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs
@@ -53,7 +53,7 @@ private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchem
return new QuantileRegressionPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, schema.Label.Name, scoreSize, quantiles);
}
- protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
+ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
{
var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
var t = score.Type;
@@ -68,7 +68,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
throw Host.Except("Label column '{0}' has type '{1}' but must be R4", schema.Label.Name, t);
}
- protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
+ private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
{
var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
var t = scoreInfo.Type;
@@ -487,7 +487,7 @@ public QuantileRegressionMamlEvaluator(IHostEnvironment env, Arguments args)
_evaluator = new QuantileRegressionEvaluator(Host, evalArgs);
}
- protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics)
+ private protected override void PrintFoldResultsCore(IChannel ch, Dictionary metrics)
{
ch.AssertValue(metrics);
@@ -505,7 +505,7 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary GetOverallMetricColumns()
yield return new MetricColumn("RSquared", QuantileRegressionEvaluator.RSquared);
}
- protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema)
+ private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema)
{
Host.CheckValue(schema, nameof(schema));
Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column");
diff --git a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
index 5523057462..6872890699 100644
--- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
@@ -84,7 +84,7 @@ public RankerEvaluator(IHostEnvironment env, Arguments args)
_labelGains = labelGains.ToArray();
}
- protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
+ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
{
var t = schema.Label.Type;
if (t != NumberType.Float && !t.IsKey)
@@ -100,7 +100,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
}
}
- protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema)
+ private protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema)
{
var t = schema.Group.Type;
if (!t.IsKey)
@@ -112,13 +112,13 @@ protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema)
}
// Add also the group column.
- protected override Func GetActiveColsCore(RoleMappedSchema schema)
+ private protected override Func GetActiveColsCore(RoleMappedSchema schema)
{
var pred = base.GetActiveColsCore(schema);
return i => i == schema.Group.Index || pred(i);
}
- protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
+ private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
{
return new Aggregator(Host, _labelGains, _truncationLevel, _groupSummary, schema.Weight != null, stratName);
}
@@ -147,7 +147,7 @@ public override IEnumerable GetOverallMetricColumns()
groupName: "at", nameFormat: string.Format("{0} @{{0}}", MaxDcg));
}
- protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
+ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
out Action, Aggregator> addAgg, out Func> consolidate)
{
var stratCol = new List();
@@ -440,7 +440,7 @@ public Aggregator(IHostEnvironment env, Double[] labelGains, int truncationLevel
GroupId = new List>();
}
- public override void InitializeNextPass(Row row, RoleMappedSchema schema)
+ internal override void InitializeNextPass(Row row, RoleMappedSchema schema)
{
Contracts.Assert(PassNum < 1);
Contracts.AssertValue(schema.Label);
@@ -882,14 +882,14 @@ public RankerMamlEvaluator(IHostEnvironment env, Arguments args)
_groupIdCol = args.GroupIdColumn;
}
- protected override IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema)
+ private protected override IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema)
{
var cols = base.GetInputColumnRolesCore(schema);
var groupIdCol = EvaluateUtils.GetColName(_groupIdCol, schema.Group, DefaultColumnNames.GroupId);
return cols.Prepend(RoleMappedSchema.ColumnRole.Group.Bind(groupIdCol));
}
- protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics)
+ private protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics)
{
ch.AssertNonEmpty(metrics);
@@ -929,7 +929,7 @@ private bool TryGetGroupSummaryMetrics(Dictionary[] metrics,
return true;
}
- protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema)
+ private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema)
{
Host.CheckValue(schema, nameof(schema));
Host.CheckValue(schema.Label, nameof(schema), "Data must contain a label column");
diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs
index 48d94e3159..c4004a494e 100644
--- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs
@@ -52,7 +52,7 @@ public RegressionEvaluator(IHostEnvironment env, Arguments args)
{
}
- protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
+ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
{
var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
var t = score.Type;
@@ -64,7 +64,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
throw Host.Except("Label column '{0}' has type '{1}' but must be R4", schema.Label.Name, t);
}
- protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
+ private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
{
return new Aggregator(Host, LossFunction, schema.Weight != null, stratName);
}
@@ -350,7 +350,7 @@ public RegressionMamlEvaluator(IHostEnvironment env, Arguments args)
_evaluator = new RegressionEvaluator(Host, evalArgs);
}
- protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema)
+ private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema)
{
Host.CheckValue(schema, nameof(schema));
Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column");
diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs
index 8343e9765e..c10cb87903 100644
--- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs
+++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs
@@ -38,12 +38,13 @@ protected RegressionLossEvaluatorBase(ArgumentsBase args, IHostEnvironment env,
public abstract class RegressionEvaluatorBase : RegressionLossEvaluatorBase
where TAgg : RegressionEvaluatorBase.RegressionAggregatorBase
{
- protected RegressionEvaluatorBase(ArgumentsBase args, IHostEnvironment env, string registrationName)
+ [BestFriend]
+ private protected RegressionEvaluatorBase(ArgumentsBase args, IHostEnvironment env, string registrationName)
: base(args, env, registrationName)
{
}
- protected override void GetAggregatorConsolidationFuncs(TAgg aggregator, AggregatorDictionaryBase[] dictionaries,
+ private protected override void GetAggregatorConsolidationFuncs(TAgg aggregator, AggregatorDictionaryBase[] dictionaries,
out Action, TAgg> addAgg, out Func> consolidate)
{
var stratCol = new List();
@@ -184,7 +185,8 @@ public void Update(ref TScore score, float label, float weight, ref TMetrics los
public abstract CountersBase UnweightedCounters { get; }
public abstract CountersBase WeightedCounters { get; }
- protected RegressionAggregatorBase(IHostEnvironment env, IRegressionLoss lossFunction, bool weighted, string stratName)
+ [BestFriend]
+ private protected RegressionAggregatorBase(IHostEnvironment env, IRegressionLoss lossFunction, bool weighted, string stratName)
: base(env, stratName)
{
Host.AssertValue(lossFunction);
@@ -192,7 +194,7 @@ protected RegressionAggregatorBase(IHostEnvironment env, IRegressionLoss lossFun
Weighted = weighted;
}
- public override void InitializeNextPass(Row row, RoleMappedSchema schema)
+ internal override void InitializeNextPass(Row row, RoleMappedSchema schema)
{
Contracts.Assert(PassNum < 1);
Contracts.AssertValue(schema.Label);
diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
index ae02024084..ee800ea4a0 100644
--- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
+++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
@@ -125,7 +125,8 @@ private static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBound
return MultiClassClassifierScorer.LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)mapper, type as VectorType, getter, MetadataUtils.Kinds.TrainingLabelValues, CanWrap);
}
- public BinaryClassifierScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
+ [BestFriend]
+ internal BinaryClassifierScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
: base(args, env, data, WrapIfNeeded(env, mapper, trainSchema), trainSchema, RegistrationName, MetadataUtils.Const.ScoreColumnKind.BinaryClassification,
Contracts.CheckRef(args, nameof(args)).ThresholdColumn, OutputTypeMatches, GetPredColType)
{
@@ -168,7 +169,7 @@ public static BinaryClassifierScorer Create(IHostEnvironment env, ModelLoadConte
return h.Apply("Loading Model", ch => new BinaryClassifierScorer(h, ctx, input));
}
- protected override void SaveCore(ModelSaveContext ctx)
+ private protected override void SaveCore(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
ctx.SetVersionInfo(GetVersionInfo());
@@ -269,7 +270,7 @@ private void GetPredictedLabelCoreAsKey(Float score, ref uint value)
value = (uint)(score > _threshold ? 2 : score <= _threshold ? 1 : 0);
}
- protected override JToken PredictedLabelPfa(string[] mapperOutputs)
+ private protected override JToken PredictedLabelPfa(string[] mapperOutputs)
{
Contracts.CheckParam(Utils.Size(mapperOutputs) >= 1, nameof(mapperOutputs));
diff --git a/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs b/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs
index ce1e447de7..1b5aa01c05 100644
--- a/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs
+++ b/src/Microsoft.ML.Data/Scorers/ClusteringScorer.cs
@@ -44,7 +44,8 @@ private static VersionInfo GetVersionInfo()
private const string RegistrationName = "ClusteringScore";
- public ClusteringScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
+ [BestFriend]
+ internal ClusteringScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
: base(args, env, data, mapper, trainSchema, RegistrationName, MetadataUtils.Const.ScoreColumnKind.Clustering,
MetadataUtils.Const.ScoreValueKind.Score, OutputTypeMatches, GetPredColType)
{
@@ -73,7 +74,7 @@ public static ClusteringScorer Create(IHostEnvironment env, ModelLoadContext ctx
return h.Apply("Loading Model", ch => new ClusteringScorer(h, ctx, input));
}
- protected override void SaveCore(ModelSaveContext ctx)
+ private protected override void SaveCore(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
ctx.SetVersionInfo(GetVersionInfo());
@@ -127,7 +128,7 @@ protected override Delegate GetPredictedLabelGetter(Row output, out Delegate sco
return predFn;
}
- protected override JToken PredictedLabelPfa(string[] mapperOutputs)
+ private protected override JToken PredictedLabelPfa(string[] mapperOutputs)
{
Contracts.Assert(Utils.Size(mapperOutputs) == 1);
return PfaUtils.Call("a.argmax", mapperOutputs[0]);
diff --git a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs
index e8090fc614..ceea843335 100644
--- a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs
+++ b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs
@@ -64,7 +64,8 @@ public sealed class Arguments : ScorerArgumentsBase
// REVIEW: the scorer currently ignores the 'suffix' argument from the base class. It should respect it.
}
- public static IDataScorerTransform Create(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
+ [BestFriend]
+ internal static IDataScorerTransform Create(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(data, nameof(data));
@@ -82,7 +83,8 @@ public static IDataScorerTransform Create(IHostEnvironment env, Arguments args,
return scoredPipe;
}
- public static ISchemaBindableMapper Create(IHostEnvironment env, Arguments args, IPredictor predictor)
+ [BestFriend]
+ internal static ISchemaBindableMapper Create(IHostEnvironment env, Arguments args, IPredictor predictor)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(args, nameof(args));
@@ -116,7 +118,10 @@ public static IDataScorerTransform Create(IHostEnvironment env, Arguments args,
return Create(env, args, data, boundMapper, null);
}
- public static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadContext ctx)
+ ///
+ /// Create method corresponding to .
+ ///
+ private static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadContext ctx)
{
return new BindableMapper(env, ctx);
}
diff --git a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs
index f301792603..8cf6eaaf7d 100644
--- a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs
+++ b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs
@@ -138,7 +138,7 @@ private static VersionInfo GetVersionInfo()
private const string RegistrationName = "GenericScore";
private readonly Bindings _bindings;
- protected override BindingsBase GetBindings() => _bindings;
+ private protected override BindingsBase GetBindings() => _bindings;
public override Schema OutputSchema { get; }
@@ -149,7 +149,8 @@ private static VersionInfo GetVersionInfo()
///
/// The entry point for creating a .
///
- public GenericScorer(IHostEnvironment env, ScorerArgumentsBase args, IDataView data,
+ [BestFriend]
+ internal GenericScorer(IHostEnvironment env, ScorerArgumentsBase args, IDataView data,
ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
: base(env, data, RegistrationName, Contracts.CheckRef(mapper, nameof(mapper)).Bindable)
{
@@ -199,7 +200,7 @@ public static GenericScorer Create(IHostEnvironment env, ModelLoadContext ctx, I
return h.Apply("Loading Model", ch => new GenericScorer(h, ctx, input));
}
- protected override void SaveCore(ModelSaveContext ctx)
+ private protected override void SaveCore(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
ctx.SetVersionInfo(GetVersionInfo());
diff --git a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs
index a8ce0092c0..7de5b0cb5a 100644
--- a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs
+++ b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs
@@ -79,7 +79,6 @@ public sealed class LabelNameBindableMapper : ISchemaBindableMapper, ICanSaveMod
public VectorType Type => _type;
bool ICanSavePfa.CanSavePfa => (_bindable as ICanSavePfa)?.CanSavePfa == true;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (_bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
- public ISchemaBindableMapper InnerBindable => _bindable;
private static VersionInfo GetVersionInfo()
{
@@ -138,8 +137,6 @@ private LabelNameBindableMapper(IHost host, ModelLoadContext ctx)
ctx.LoadNonEmptyString() : MetadataUtils.Kinds.SlotNames;
}
- public ISchemaBindableMapper Clone(ISchemaBindableMapper inner) => new LabelNameBindableMapper(_host, inner, _type, _getter, _metadataKind, _canWrap);
-
private Delegate DecodeInit(object value)
{
_host.CheckDecode(value is VBuffer);
@@ -148,7 +145,10 @@ private Delegate DecodeInit(object value)
return buffGetter;
}
- public static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadContext ctx)
+ ///
+ /// Method corresponding to .
+ ///
+ private static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(LoaderSignature);
@@ -211,7 +211,7 @@ bool IBindableCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, s
return ((IBindableCanSaveOnnx)_bindable).SaveAsOnnx(ctx, schema, outputNames);
}
- public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
+ ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema)
{
var innerBound = _bindable.Bind(env, schema);
if (_canWrap != null && !_canWrap(innerBound, _type))
@@ -220,7 +220,7 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
return Utils.MarshalInvoke(CreateBound, _type.ItemType.RawType, env, (ISchemaBoundRowMapper)innerBound, _type, _getter, _metadataKind, _canWrap);
}
- public static ISchemaBoundMapper CreateBound(IHostEnvironment env, ISchemaBoundRowMapper mapper, VectorType type, Delegate getter,
+ internal static ISchemaBoundMapper CreateBound(IHostEnvironment env, ISchemaBoundRowMapper mapper, VectorType type, Delegate getter,
string metadataKind, Func canWrap)
{
Contracts.AssertValue(env);
@@ -393,7 +393,7 @@ private static ISchemaBoundMapper WrapIfNeeded(IHostEnvironment env, ISchemaBoun
/// from the model of a bindable mapper)
/// Whether we can call with
/// this mapper and expect it to succeed
- public static bool CanWrap(ISchemaBoundMapper mapper, ColumnType labelNameType)
+ private static bool CanWrap(ISchemaBoundMapper mapper, ColumnType labelNameType)
{
Contracts.AssertValue(mapper);
Contracts.AssertValue(labelNameType);
@@ -414,7 +414,7 @@ public static bool CanWrap(ISchemaBoundMapper mapper, ColumnType labelNameType)
return labelNameType.IsVector && labelNameType.VectorSize == scoreType.VectorSize;
}
- public static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
+ private static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
{
Contracts.AssertValue(env);
env.AssertValue(mapper);
@@ -436,7 +436,8 @@ public static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBoundM
return LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)mapper, type as VectorType, getter, MetadataUtils.Kinds.SlotNames, CanWrap);
}
- public MultiClassClassifierScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
+ [BestFriend]
+ internal MultiClassClassifierScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
: base(args, env, data, WrapIfNeeded(env, mapper, trainSchema), trainSchema, RegistrationName, MetadataUtils.Const.ScoreColumnKind.MultiClassClassification,
MetadataUtils.Const.ScoreValueKind.Score, OutputTypeMatches, GetPredColType)
{
@@ -454,7 +455,10 @@ private MultiClassClassifierScorer(IHost host, ModelLoadContext ctx, IDataView i
//
}
- public static MultiClassClassifierScorer Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ ///
+ /// Corresponds to .
+ ///
+ private static MultiClassClassifierScorer Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(RegistrationName);
@@ -464,7 +468,7 @@ public static MultiClassClassifierScorer Create(IHostEnvironment env, ModelLoadC
return h.Apply("Loading Model", ch => new MultiClassClassifierScorer(h, ctx, input));
}
- protected override void SaveCore(ModelSaveContext ctx)
+ private protected override void SaveCore(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);
ctx.SetVersionInfo(GetVersionInfo());
@@ -518,7 +522,7 @@ protected override Delegate GetPredictedLabelGetter(Row output, out Delegate sco
return predFn;
}
- protected override JToken PredictedLabelPfa(string[] mapperOutputs)
+ private protected override JToken PredictedLabelPfa(string[] mapperOutputs)
{
Contracts.Assert(Utils.Size(mapperOutputs) == 1);
return PfaUtils.Call("a.argmax", mapperOutputs[0]);
diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
index e6f044c48f..0a869d250b 100644
--- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
+++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
@@ -30,7 +30,8 @@ public abstract class ThresholdArgumentsBase : ScorerArgumentsBase
public string ThresholdColumn = MetadataUtils.Const.ScoreValueKind.Score;
}
- protected sealed class BindingsImpl : BindingsBase
+ [BestFriend]
+ private protected sealed class BindingsImpl : BindingsBase
{
// Column index of the score column in Mapper's schema.
public readonly int ScoreColumnIndex;
@@ -272,15 +273,18 @@ public override Func GetActiveMapperColumns(bool[] active)
}
}
- protected readonly BindingsImpl Bindings;
- protected override BindingsBase GetBindings() => Bindings;
+ [BestFriend]
+ private protected readonly BindingsImpl Bindings;
+ [BestFriend]
+ private protected sealed override BindingsBase GetBindings() => Bindings;
public override Schema OutputSchema { get; }
bool ICanSavePfa.CanSavePfa => (Bindable as ICanSavePfa)?.CanSavePfa == true;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (Bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
- protected PredictedLabelScorerBase(ScorerArgumentsBase args, IHostEnvironment env, IDataView data,
+ [BestFriend]
+ private protected PredictedLabelScorerBase(ScorerArgumentsBase args, IHostEnvironment env, IDataView data,
ISchemaBoundMapper mapper, RoleMappedSchema trainSchema, string registrationName, string scoreColKind, string scoreColName,
Func outputTypeMatches, Func getPredColType)
: base(env, data, registrationName, Contracts.CheckRef(mapper, nameof(mapper)).Bindable)
@@ -314,7 +318,8 @@ protected PredictedLabelScorerBase(IHostEnvironment env, PredictedLabelScorerBas
OutputSchema = Schema.Create(Bindings);
}
- protected PredictedLabelScorerBase(IHost host, ModelLoadContext ctx, IDataView input,
+ [BestFriend]
+ private protected PredictedLabelScorerBase(IHost host, ModelLoadContext ctx, IDataView input,
Func outputTypeMatches, Func getPredColType)
: base(host, ctx, input)
{
@@ -327,7 +332,7 @@ protected PredictedLabelScorerBase(IHost host, ModelLoadContext ctx, IDataView i
OutputSchema = Schema.Create(Bindings);
}
- protected override void SaveCore(ModelSaveContext ctx)
+ private protected override void SaveCore(ModelSaveContext ctx)
{
Host.AssertValue(ctx);
Bindings.Save(ctx);
@@ -360,7 +365,8 @@ void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx)
ctx.DeclareVar(derivedName, predictedLabelExpression);
}
- protected abstract JToken PredictedLabelPfa(string[] mapperOutputs);
+ [BestFriend]
+ private protected abstract JToken PredictedLabelPfa(string[] mapperOutputs);
void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx) => SaveAsOnnxCore(ctx);
diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
index 25d31e13bc..08bf96af22 100644
--- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
+++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
@@ -47,7 +47,8 @@ public abstract class PredictionTransformerBase : IPredictionTr
protected const string DirModel = "Model";
protected const string DirTransSchema = "TrainSchema";
protected readonly IHost Host;
- protected ISchemaBindableMapper BindableMapper;
+ [BestFriend]
+ private protected ISchemaBindableMapper BindableMapper;
protected Schema TrainSchema;
public bool IsRowToRowMapper => true;
diff --git a/src/Microsoft.ML.Data/Scorers/QuantileRegressionScorer.cs b/src/Microsoft.ML.Data/Scorers/QuantileRegressionScorer.cs
index 032df06b99..2dec3a0e4a 100644
--- a/src/Microsoft.ML.Data/Scorers/QuantileRegressionScorer.cs
+++ b/src/Microsoft.ML.Data/Scorers/QuantileRegressionScorer.cs
@@ -18,7 +18,7 @@
namespace Microsoft.ML.Runtime.Data
{
- public static class QuantileRegressionScorerTransform
+ internal static class QuantileRegressionScorerTransform
{
public sealed class Arguments : ScorerArgumentsBase
{
@@ -26,12 +26,18 @@ public sealed class Arguments : ScorerArgumentsBase
public string Quantiles = "0,0.25,0.5,0.75,1";
}
- public static IDataScorerTransform Create(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
+ ///
+ /// Constructor corresponding to .
+ ///
+ private static IDataScorerTransform Create(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
{
return new GenericScorer(env, args, data, mapper, trainSchema);
}
- public static ISchemaBindableMapper Create(IHostEnvironment env, Arguments args, IPredictor predictor)
+ ///
+ /// Constructor corresponding to .
+ ///
+ private static ISchemaBindableMapper Create(IHostEnvironment env, Arguments args, IPredictor predictor)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(args, nameof(args));
diff --git a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs
index 54375cfc25..697a23309e 100644
--- a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs
+++ b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs
@@ -19,7 +19,8 @@ namespace Microsoft.ML.Runtime.Data
///
public abstract class RowToRowScorerBase : RowToRowMapperTransformBase, IDataScorerTransform
{
- public abstract class BindingsBase : ScorerBindingsBase
+ [BestFriend]
+ private protected abstract class BindingsBase : ScorerBindingsBase
{
public readonly ISchemaBoundRowMapper RowMapper;
@@ -30,16 +31,19 @@ protected BindingsBase(Schema schema, ISchemaBoundRowMapper mapper, string suffi
}
}
- protected readonly ISchemaBindableMapper Bindable;
+ [BestFriend]
+ private protected readonly ISchemaBindableMapper Bindable;
- protected RowToRowScorerBase(IHostEnvironment env, IDataView input, string registrationName, ISchemaBindableMapper bindable)
+ [BestFriend]
+ private protected RowToRowScorerBase(IHostEnvironment env, IDataView input, string registrationName, ISchemaBindableMapper bindable)
: base(env, registrationName, input)
{
Contracts.AssertValue(bindable);
Bindable = bindable;
}
- protected RowToRowScorerBase(IHost host, ModelLoadContext ctx, IDataView input)
+ [BestFriend]
+ private protected RowToRowScorerBase(IHost host, ModelLoadContext ctx, IDataView input)
: base(host, input)
{
ctx.LoadModel(host, out Bindable, "SchemaBindableMapper");
@@ -56,7 +60,8 @@ public sealed override void Save(ModelSaveContext ctx)
///
/// The main save method handles saving the _bindable. This should do everything else.
///
- protected abstract void SaveCore(ModelSaveContext ctx);
+ [BestFriend]
+ private protected abstract void SaveCore(ModelSaveContext ctx);
///
/// For the ITransformTemplate implementation.
@@ -66,7 +71,8 @@ public sealed override void Save(ModelSaveContext ctx)
///
/// Derived classes provide the specific bindings object.
///
- protected abstract BindingsBase GetBindings();
+ [BestFriend]
+ private protected abstract BindingsBase GetBindings();
///
/// Produces the set of active columns for the scorer (as a bool[] of length bindings.ColumnCount),
@@ -296,11 +302,12 @@ public abstract class ScorerArgumentsBase
}
///
- /// Base bindings for a scorer based on an ISchemaBoundMapper. This assumes that input schema columns
+ /// Base bindings for a scorer based on an . This assumes that input schema columns
/// are echoed, followed by zero or more derived columns, followed by the mapper generated columns.
/// The names of the derived columns and mapper generated columns have an optional suffix appended.
///
- public abstract class ScorerBindingsBase : ColumnBindingsBase
+ [BestFriend]
+ internal abstract class ScorerBindingsBase : ColumnBindingsBase
{
///
/// The schema bound mapper.
diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs
index 5da07c1c95..e6142e3b50 100644
--- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs
+++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs
@@ -34,7 +34,7 @@ namespace Microsoft.ML.Runtime.Data
///
/// This is a base class for wrapping s in an .
///
- public abstract class SchemaBindablePredictorWrapperBase : ISchemaBindableMapper, ICanSaveModel, ICanSaveSummary,
+ internal abstract class SchemaBindablePredictorWrapperBase : ISchemaBindableMapper, ICanSaveModel, ICanSaveSummary,
IBindableCanSavePfa, IBindableCanSaveOnnx
{
// The ctor guarantees that Predictor is non-null. It also ensures that either
@@ -115,7 +115,7 @@ bool IBindableCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, s
[BestFriend]
private protected virtual bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) => false;
- public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
+ ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema)
{
Contracts.CheckValue(env, nameof(env));
@@ -142,7 +142,8 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
}
}
- protected abstract ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema);
+ [BestFriend]
+ private protected abstract ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema);
protected virtual Delegate GetPredictionGetter(Row input, int colSrc)
{
@@ -240,7 +241,7 @@ public Row GetRow(Row input, Func predicate)
/// This class is a wrapper for all s except for quantile regression predictors,
/// and calibrated binary classification predictors.
///
- public sealed class SchemaBindablePredictorWrapper : SchemaBindablePredictorWrapperBase
+ internal sealed class SchemaBindablePredictorWrapper : SchemaBindablePredictorWrapperBase
{
public const string LoaderSignature = "SchemaBindableWrapper";
private static VersionInfo GetVersionInfo()
@@ -316,7 +317,7 @@ private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema
return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(schema.Feature.Name));
}
- protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema)
+ private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema)
{
var outputSchema = Schema.Create(new ScoreMapperSchema(ScoreType, _scoreColumnKind));
return new SingleValueRowMapper(schema, this, outputSchema);
@@ -352,7 +353,7 @@ private static string GetScoreColumnKind(IPredictor predictor)
/// This is an wrapper for calibrated binary classification predictors.
/// They need a separate wrapper because they return two values instead of one: the raw score and the probability.
///
- public sealed class SchemaBindableBinaryPredictorWrapper : SchemaBindablePredictorWrapperBase
+ internal sealed class SchemaBindableBinaryPredictorWrapper : SchemaBindablePredictorWrapperBase
{
public const string LoaderSignature = "BinarySchemaBindable";
private static VersionInfo GetVersionInfo()
@@ -450,7 +451,7 @@ private void CheckValid(out IValueMapperDist distMapper)
"Invalid probability type for the IValueMapperDist");
}
- protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema)
+ private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema)
{
if (Predictor.PredictionKind != PredictionKind.BinaryClassification)
ch.Warning("Scoring predictor of kind '{0}' as '{1}'.", Predictor.PredictionKind, PredictionKind.BinaryClassification);
@@ -577,9 +578,10 @@ public Row GetRow(Row input, Func predicate)
///
/// This is an wrapper for quantile regression predictors. They need a separate
- /// wrapper because they need the quantiles to create the ISchemaBound.
+ /// wrapper because they need the quantiles to create the .
///
- public sealed class SchemaBindableQuantileRegressionPredictor : SchemaBindablePredictorWrapperBase
+ [BestFriend]
+ internal sealed class SchemaBindableQuantileRegressionPredictor : SchemaBindablePredictorWrapperBase
{
public const string LoaderSignature = "QuantileSchemaBindable";
private static VersionInfo GetVersionInfo()
@@ -650,7 +652,7 @@ public static SchemaBindableQuantileRegressionPredictor Create(IHostEnvironment
return new SchemaBindableQuantileRegressionPredictor(env, ctx);
}
- protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema)
+ private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema)
{
return new SingleValueRowMapper(schema, this, Schema.Create(new SchemaImpl(ScoreType, _quantiles)));
}
diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs
index 15957c225c..28bca6c58f 100644
--- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs
+++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs
@@ -316,7 +316,7 @@ public static void LoadModel(ModelLoadContext ctx, int cv, out bool useLog, out
/// It tracks min, max, number of non-sparse values (vCount) and number of ProcessValue() calls (trainCount).
/// NaNs are ignored when updating min and max.
///
- public sealed class MinMaxSngAggregator : IColumnAggregator>
+ internal sealed class MinMaxSngAggregator : IColumnAggregator>
{
private readonly TFloat[] _min;
private readonly TFloat[] _max;
diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs b/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs
index 0f1ea0ef62..b26111b03c 100644
--- a/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs
+++ b/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs
@@ -40,7 +40,8 @@ internal interface IColumnFunctionBuilder
///
/// Interface to define an aggregate function over values
///
- public interface IColumnAggregator
+ [BestFriend]
+ internal interface IColumnAggregator
{
///
/// Updates the aggregate function with a value
@@ -68,7 +69,7 @@ internal interface IColumnFunction : ICanSaveModel
NormalizingTransformer.NormalizerModelParametersBase GetNormalizerModelParams();
}
- public static class NormalizeUtils
+ internal static class NormalizeUtils
{
///
/// Returns whether the feature column in the schema is indicated to be normalized. If the features column is not
diff --git a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs
index e8ec5f4e53..dbc4e4d685 100644
--- a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs
+++ b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransformer.cs
@@ -20,7 +20,8 @@
namespace Microsoft.ML.Transforms
{
- public static class ScoringTransformer
+ [BestFriend]
+ internal static class ScoringTransformer
{
public sealed class Arguments : TransformInputBase
{
diff --git a/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs b/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs
index e7a61c42d6..0b962f435e 100644
--- a/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs
+++ b/src/Microsoft.ML.Data/Utilities/ModelFileUtils.cs
@@ -20,7 +20,8 @@ namespace Microsoft.ML.Runtime.Model
///
/// This class provides utilities for loading components from the model file generated by MAML commands.
///
- public static class ModelFileUtils
+ [BestFriend]
+ internal static class ModelFileUtils
{
public const string DirPredictor = "Predictor";
public const string DirDataLoaderModel = "DataLoaderModel";
diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs
index d6717d7224..97ddf9ef43 100644
--- a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs
+++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs
@@ -218,7 +218,7 @@ protected override void SaveCore(ModelSaveContext ctx)
ctx.SaveModel(Combiner, "Combiner");
}
- public override ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
+ private protected override ISchemaBoundMapper BindCore(IHostEnvironment env, RoleMappedSchema schema)
{
return new Bound(this, schema);
}
@@ -558,7 +558,9 @@ public static SchemaBindablePipelineEnsembleBase Create(IHostEnvironment env, Mo
}
}
- public abstract ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema);
+ ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema) => BindCore(env, schema);
+
+ private protected abstract ISchemaBoundMapper BindCore(IHostEnvironment env, RoleMappedSchema schema);
void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema)
{
diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/RegressionTree.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/RegressionTree.cs
index 834a4188f1..7cd371ba91 100644
--- a/src/Microsoft.ML.FastTree/TreeEnsemble/RegressionTree.cs
+++ b/src/Microsoft.ML.FastTree/TreeEnsemble/RegressionTree.cs
@@ -1119,7 +1119,7 @@ public void RemapFeatures(int[] oldToNewFeatures)
/// A map of feature index (in the features array)
/// to the ID as it will be written in the file. This instance should be
/// used for all
- public void ToTreeEnsembleFormat(StringBuilder sbEvaluator, StringBuilder sbInput, FeaturesToContentMap featureContents,
+ internal void ToTreeEnsembleFormat(StringBuilder sbEvaluator, StringBuilder sbInput, FeaturesToContentMap featureContents,
ref int evaluatorCounter, Dictionary featureToId)
{
Contracts.AssertValue(sbEvaluator);
@@ -1259,7 +1259,7 @@ private void ToTreeEnsembleFormatForCategoricalSplit(StringBuilder sbEvaluator,
}
// prints the tree out as a string (in old Bing format used by LambdaMART and AdIndex)
- public string ToOldIni(FeatureNameCollection featureNames)
+ internal string ToOldIni(FeatureNameCollection featureNames)
{
// print the root node
StringBuilder output = new StringBuilder();
diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs
index 08ddb00a32..b1a21f9541 100644
--- a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs
+++ b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs
@@ -107,7 +107,7 @@ public void RemapFeatures(int[] oldToNewFeatures)
///
/// returns the ensemble in the production TreeEnsemble format
///
- public string ToTreeEnsembleIni(FeaturesToContentMap fmap,
+ internal string ToTreeEnsembleIni(FeaturesToContentMap fmap,
string trainingParams, bool appendFeatureGain, bool includeZeroGainFeatures = true)
{
StringBuilder sbEvaluator = new StringBuilder();
@@ -306,7 +306,7 @@ public void GetOutputs(Dataset dataset, double[] outputs, int prefix)
Parallel.Invoke(new ParallelOptions { MaxDegreeOfParallelism = BlockingThreadPool.NumThreads }, actions);
}
- public string ToGainSummary(FeaturesToContentMap fmap, Dictionary featureToID, int prefix, bool includeZeroGainFeatures, bool normalize, int startingCommentNumber)
+ internal string ToGainSummary(FeaturesToContentMap fmap, Dictionary featureToID, int prefix, bool includeZeroGainFeatures, bool normalize, int startingCommentNumber)
{
if (_trees.Count == 0)
return string.Empty;
@@ -397,7 +397,7 @@ public FeatureToGainMap(IList trees, bool normalize)
/// A class that given either a
/// provides a mechanism for getting the corresponding input INI content for the features.
///
- public sealed class FeaturesToContentMap
+ internal sealed class FeaturesToContentMap
{
private readonly VBuffer> _content;
private readonly VBuffer> _names;
diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
index b6ac50bb70..27951364fc 100644
--- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
+++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
@@ -525,7 +525,7 @@ private void GetPathSlotNames(int col, ref VBuffer> dst)
dst = editor.Commit();
}
- public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
+ ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema)
{
Contracts.AssertValue(env);
env.AssertValue(schema);
@@ -654,7 +654,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
vm.InputType.VectorSize, data.Schema.Feature.Type.VectorSize);
}
- var bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor);
+ ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor);
var bound = bindable.Bind(env, data.Schema);
xf = new GenericScorer(env, scorerArgs, input, bound, data.Schema);
}
@@ -718,7 +718,7 @@ public static IDataTransform CreateForEntryPoint(IHostEnvironment env, Arguments
vm.InputType.VectorSize, data.Schema.Feature.Type.VectorSize);
}
- var bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor);
+ ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor);
var bound = bindable.Bind(env, data.Schema);
return new GenericScorer(env, scorerArgs, data.Data, bound, data.Schema);
}
diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs
index fee4a25d07..3cbb610e97 100644
--- a/src/Microsoft.ML.PCA/PcaTrainer.cs
+++ b/src/Microsoft.ML.PCA/PcaTrainer.cs
@@ -530,7 +530,7 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
}
}
- public IDataView GetSummaryDataView(RoleMappedSchema schema)
+ IDataView ICanGetSummaryAsIDataView.GetSummaryDataView(RoleMappedSchema schema)
{
var bldr = new ArrayDataViewBuilder(Host);
diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs b/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs
index f33bc405b8..fb78fdcd33 100644
--- a/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs
+++ b/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs
@@ -261,7 +261,7 @@ private float Score(int columnIndex, int rowIndex)
/// Create a row mapper based on regression scorer. Because matrix factorization predictor maps a tuple of a row ID (u) and a column ID (v)
/// to the expected numerical value at the u-th row and the v-th column in the considered matrix, it is essentially a regressor.
///
- public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
+ ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema)
{
Contracts.AssertValue(env);
env.AssertValue(schema);
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs
index 636f49fb60..a417935210 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs
@@ -173,7 +173,7 @@ internal float CalculateResponse(ValueGetter>[] getters, VBuffer<
return modelResponse;
}
- public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
+ ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema)
=> new FieldAwareFactorizationMachineScalarRowMapper(env, schema, Schema.Create(new BinaryClassifierSchema()), this);
internal void CopyLinearWeightsTo(float[] linearWeights)
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
index 6085504bb6..0be38f3547 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs
@@ -970,7 +970,7 @@ private void SaveLabelNames(ModelSaveContext ctx, BinaryWriter writer)
}
}
- public IDataView GetSummaryDataView(RoleMappedSchema schema)
+ IDataView ICanGetSummaryAsIDataView.GetSummaryDataView(RoleMappedSchema schema)
{
var bldr = new ArrayDataViewBuilder(Host);
diff --git a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs
index 0a75f3b948..47c993dcfa 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs
@@ -6,6 +6,7 @@
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.CpuMath;
+using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Runtime.Model;
@@ -324,7 +325,7 @@ private List GetUnorderedCoefficientStatistics(LinearBina
///
/// Gets the coefficient statistics as an object.
///
- public CoefficientStatistics[] GetCoefficientStatistics(LinearBinaryModelParameters parent, RoleMappedSchema schema, int paramCountCap)
+ internal CoefficientStatistics[] GetCoefficientStatistics(LinearBinaryModelParameters parent, RoleMappedSchema schema, int paramCountCap)
{
Contracts.AssertValue(_env);
_env.CheckValue(parent, nameof(parent));
@@ -345,7 +346,7 @@ public CoefficientStatistics[] GetCoefficientStatistics(LinearBinaryModelParamet
return order.Prepend(new[] { new CoefficientStatistics("(Bias)", bias, stdError, zScore, pValue) }).ToArray();
}
- public void SaveText(TextWriter writer, LinearBinaryModelParameters parent, RoleMappedSchema schema, int paramCountCap)
+ internal void SaveText(TextWriter writer, LinearBinaryModelParameters parent, RoleMappedSchema schema, int paramCountCap)
{
Contracts.AssertValue(_env);
_env.CheckValue(writer, nameof(writer));
@@ -383,7 +384,10 @@ public void SaveText(TextWriter writer, LinearBinaryModelParameters parent, Role
writer.WriteLine("Significance codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1");
}
- public void SaveSummaryInKeyValuePairs(LinearBinaryModelParameters parent,
+ ///
+ /// Support method for linear models and .
+ ///
+ internal void SaveSummaryInKeyValuePairs(LinearBinaryModelParameters parent,
RoleMappedSchema schema, int paramCountCap, List> resultCollection)
{
Contracts.AssertValue(_env);