Skip to content

Commit 93a8fe3

Browse files
authored
Internalize RoleMappedSchema and implications thereof (#1902)
* Internalization of ISchemaBindable/BoundMappers * Internalize much of the IEvaluator implementation infrastructure. * Internalize other smaller scale usages of RoleMappedSchema. * Internalize RoleMappedSchema.
1 parent dbc38be commit 93a8fe3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+317
-228
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Data;
6+
using System;
7+
8+
namespace Microsoft.ML.Runtime.Data
9+
{
10+
/// <summary>
11+
/// This interface maps an input <see cref="Row"/> to an output <see cref="Row"/>. Typically, the output contains
12+
/// both the input columns and new columns added by the implementing class, although some implementations may
13+
/// return a subset of the input columns.
14+
/// This interface is similar to <see cref="ISchemaBoundRowMapper"/>, except it does not have any input role mappings,
15+
/// so to rebind, the same input column names must be used.
16+
/// Implementations of this interface are typically created over defined input <see cref="Schema"/>.
17+
/// </summary>
18+
public interface IRowToRowMapper
19+
{
20+
/// <summary>
21+
/// Mappers are defined as accepting inputs with this very specific schema.
22+
/// </summary>
23+
Schema InputSchema { get; }
24+
25+
/// <summary>
26+
/// Gets an instance of <see cref="Schema"/> which describes the columns' names and types in the output generated by this mapper.
27+
/// </summary>
28+
Schema OutputSchema { get; }
29+
30+
/// <summary>
31+
/// Given a predicate specifying which columns are needed, return a predicate indicating which input columns are
32+
/// needed. The domain of the function is defined over the indices of the columns of <see cref="Schema.Count"/>
33+
/// for <see cref="InputSchema"/>.
34+
/// </summary>
35+
Func<int, bool> GetDependencies(Func<int, bool> predicate);
36+
37+
/// <summary>
38+
/// Get an <see cref="Row"/> with the indicated active columns, based on the input <paramref name="input"/>.
39+
/// The active columns are those for which <paramref name="active"/> returns true. Getting values on inactive
40+
/// columns of the returned row will throw. Null predicates are disallowed.
41+
///
42+
/// The <see cref="Row.Schema"/> of <paramref name="input"/> should be the same object as
43+
/// <see cref="InputSchema"/>. Implementors of this method should throw if that is not the case. Conversely,
44+
/// the returned value must have the same schema as <see cref="OutputSchema"/>.
45+
///
46+
/// This method creates a live connection between the input <see cref="Row"/> and the output <see
47+
/// cref="Row"/>. In particular, when the getters of the output <see cref="Row"/> are invoked, they invoke the
48+
/// getters of the input row and base the output values on the current values of the input <see cref="Row"/>.
49+
/// The output <see cref="Row"/> values are re-computed when requested through the getters. Also, the returned
50+
/// <see cref="Row"/> will dispose <paramref name="input"/> when it is disposed.
51+
/// </summary>
52+
Row GetRow(Row input, Func<int, bool> active);
53+
}
54+
}

src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs

Lines changed: 6 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
// See the LICENSE file in the project root for more information.
44

55
using Microsoft.ML.Data;
6-
using System;
76
using System.Collections.Generic;
87

98
namespace Microsoft.ML.Runtime.Data
@@ -21,7 +20,8 @@ namespace Microsoft.ML.Runtime.Data
2120
/// for the output schema of the <see cref="ISchemaBoundMapper"/>. In case the <see cref="ISchemaBoundRowMapper"/> interface is implemented,
2221
/// the SimpleRow class can be used in the <see cref="IRowToRowMapper.GetRow"/> method.
2322
/// </summary>
24-
public interface ISchemaBindableMapper
23+
[BestFriend]
24+
internal interface ISchemaBindableMapper
2525
{
2626
ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema);
2727
}
@@ -30,7 +30,8 @@ public interface ISchemaBindableMapper
3030
/// This interface is used to map a schema from input columns to output columns. The <see cref="ISchemaBoundMapper"/> should keep track
3131
/// of the input columns that are needed for the mapping.
3232
/// </summary>
33-
public interface ISchemaBoundMapper
33+
[BestFriend]
34+
internal interface ISchemaBoundMapper
3435
{
3536
/// <summary>
3637
/// The <see cref="RoleMappedSchema"/> that was passed to the <see cref="ISchemaBoundMapper"/> in the binding process.
@@ -56,57 +57,13 @@ public interface ISchemaBoundMapper
5657
/// <summary>
5758
/// This interface combines <see cref="ISchemaBoundMapper"/> with <see cref="IRowToRowMapper"/>.
5859
/// </summary>
59-
public interface ISchemaBoundRowMapper : ISchemaBoundMapper, IRowToRowMapper
60+
[BestFriend]
61+
internal interface ISchemaBoundRowMapper : ISchemaBoundMapper, IRowToRowMapper
6062
{
6163
/// <summary>
6264
/// There are two schemas from <see cref="ISchemaBoundMapper"/> and <see cref="IRowToRowMapper"/>.
6365
/// Since the two parent schema's are identical in all derived classes, we merge them into <see cref="OutputSchema"/>.
6466
/// </summary>
6567
new Schema OutputSchema { get; }
6668
}
67-
68-
/// <summary>
69-
/// This interface maps an input <see cref="Row"/> to an output <see cref="Row"/>. Typically, the output contains
70-
/// both the input columns and new columns added by the implementing class, although some implementations may
71-
/// return a subset of the input columns.
72-
/// This interface is similar to <see cref="ISchemaBoundRowMapper"/>, except it does not have any input role mappings,
73-
/// so to rebind, the same input column names must be used.
74-
/// Implementations of this interface are typically created over defined input <see cref="Schema"/>.
75-
/// </summary>
76-
public interface IRowToRowMapper
77-
{
78-
/// <summary>
79-
/// Mappers are defined as accepting inputs with this very specific schema.
80-
/// </summary>
81-
Schema InputSchema { get; }
82-
83-
/// <summary>
84-
/// Gets an instance of <see cref="Schema"/> which describes the columns' names and types in the output generated by this mapper.
85-
/// </summary>
86-
Schema OutputSchema { get; }
87-
88-
/// <summary>
89-
/// Given a predicate specifying which columns are needed, return a predicate indicating which input columns are
90-
/// needed. The domain of the function is defined over the indices of the columns of <see cref="Schema.Count"/>
91-
/// for <see cref="InputSchema"/>.
92-
/// </summary>
93-
Func<int, bool> GetDependencies(Func<int, bool> predicate);
94-
95-
/// <summary>
96-
/// Get an <see cref="Row"/> with the indicated active columns, based on the input <paramref name="input"/>.
97-
/// The active columns are those for which <paramref name="active"/> returns true. Getting values on inactive
98-
/// columns of the returned row will throw. Null predicates are disallowed.
99-
///
100-
/// The <see cref="Row.Schema"/> of <paramref name="input"/> should be the same object as
101-
/// <see cref="InputSchema"/>. Implementors of this method should throw if that is not the case. Conversely,
102-
/// the returned value must have the same schema as <see cref="OutputSchema"/>.
103-
///
104-
/// This method creates a live connection between the input <see cref="Row"/> and the output <see
105-
/// cref="Row"/>. In particular, when the getters of the output <see cref="Row"/> are invoked, they invoke the
106-
/// getters of the input row and base the output values on the current values of the input <see cref="Row"/>.
107-
/// The output <see cref="Row"/> values are re-computed when requested through the getters. Also, the returned
108-
/// <see cref="Row"/> will dispose <paramref name="input"/> when it is disposed.
109-
/// </summary>
110-
Row GetRow(Row input, Func<int, bool> active);
111-
}
11269
}

src/Microsoft.ML.Core/Data/RoleMappedSchema.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ public static ColumnInfo CreateFromIndex(Schema schema, int index)
100100
/// </remarks>
101101
/// <seealso cref="ColumnRole"/>
102102
/// <seealso cref="RoleMappedData"/>
103-
public sealed class RoleMappedSchema
103+
[BestFriend]
104+
internal sealed class RoleMappedSchema
104105
{
105106
private const string FeatureString = "Feature";
106107
private const string LabelString = "Label";

src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ private Stream CreateStrm(IFileHandle file)
114114
}
115115
}
116116

117-
public static class SavePredictorUtils
117+
[BestFriend]
118+
internal static class SavePredictorUtils
118119
{
119120
public static void SavePredictor(IHostEnvironment env, Stream modelStream, Stream binaryModelStream = null, Stream summaryModelStream = null,
120121
Stream textModelStream = null, Stream iniModelStream = null, Stream codeModelStream = null)

src/Microsoft.ML.Data/Commands/ScoreCommand.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ public interface IDataScorerTransform : IDataTransform, ITransformTemplate
3434
/// <param name="trainSchema">This parameter holds a snapshot of the role mapped training schema as
3535
/// it existed at the point when <paramref name="mapper"/> was trained, or <c>null</c> if it not
3636
/// available for some reason</param>
37-
public delegate void SignatureDataScorer(IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema);
37+
[BestFriend]
38+
internal delegate void SignatureDataScorer(IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema);
3839

3940
public delegate void SignatureBindableMapper(IPredictor predictor);
4041

src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
namespace Microsoft.ML.Runtime.Internal.Internallearn
1818
{
19-
public abstract class FeatureNameCollection : IEnumerable<string>
19+
[BestFriend]
20+
internal abstract class FeatureNameCollection : IEnumerable<string>
2021
{
2122
private sealed class FeatureNameCollectionSchema : ISchema
2223
{

src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ internal interface ICanGetSummaryAsIRow
143143
Row GetStatsIRowOrNull(RoleMappedSchema schema);
144144
}
145145

146-
public interface ICanGetSummaryAsIDataView
146+
[BestFriend]
147+
internal interface ICanGetSummaryAsIDataView
147148
{
148149
IDataView GetSummaryDataView(RoleMappedSchema schema);
149150
}

src/Microsoft.ML.Data/Dirty/PredictorUtils.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99

1010
namespace Microsoft.ML.Runtime.Internal.Internallearn
1111
{
12-
public static class PredictorUtils
12+
[BestFriend]
13+
internal static class PredictorUtils
1314
{
1415
/// <summary>
15-
/// Save the model summary
16+
/// Save the model summary.
1617
/// </summary>
1718
public static void SaveSummary(IChannel ch, IPredictor predictor, RoleMappedSchema schema, TextWriter writer)
1819
{

src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ public static CommonOutputs.SummaryOutput Summarize(IHostEnvironment env, Summar
4343
return output;
4444
}
4545

46-
public static IDataView GetSummaryAndStats(IHostEnvironment env, IPredictor predictor, RoleMappedSchema schema, out IDataView stats)
46+
[BestFriend]
47+
internal static IDataView GetSummaryAndStats(IHostEnvironment env, IPredictor predictor, RoleMappedSchema schema, out IDataView stats)
4748
{
4849
var calibrated = predictor as CalibratedPredictorBase;
4950
while (calibrated != null)

src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ public AnomalyDetectionEvaluator(IHostEnvironment env, Arguments args)
9191
_aucCount = args.MaxAucExamples;
9292
}
9393

94-
protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
94+
private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
9595
{
9696
var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
9797
var t = score.Type;
@@ -103,7 +103,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
103103
throw Host.Except("Label column '{0}' has type '{1}' but must be R4 or a 2-value key", schema.Label.Name, t).MarkSensitive(MessageSensitivity.Schema);
104104
}
105105

106-
protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
106+
private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
107107
{
108108
return new Aggregator(Host, _aucCount, _numTopResults, _k, _p, _streaming, schema.Name == null ? -1 : schema.Name.Index, stratName);
109109
}
@@ -124,7 +124,7 @@ public override IEnumerable<MetricColumn> GetOverallMetricColumns()
124124
yield return new MetricColumn("ThreshAtNumPos", OverallMetrics.ThreshAtNumPos, MetricColumn.Objective.Info, canBeWeighted: false);
125125
}
126126

127-
protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
127+
private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
128128
out Action<uint, ReadOnlyMemory<char>, Aggregator> addAgg, out Func<Dictionary<string, IDataView>> consolidate)
129129
{
130130
var stratCol = new List<uint>();
@@ -498,7 +498,7 @@ private void FinishOtherMetrics()
498498
}
499499
}
500500

501-
public override void InitializeNextPass(Row row, RoleMappedSchema schema)
501+
internal override void InitializeNextPass(Row row, RoleMappedSchema schema)
502502
{
503503
Host.Assert(!_streaming && PassNum < 2 || PassNum < 1);
504504
Host.AssertValue(schema.Label);
@@ -621,7 +621,7 @@ public AnomalyDetectionMamlEvaluator(IHostEnvironment env, Arguments args)
621621
_evaluator = new AnomalyDetectionEvaluator(Host, evalArgs);
622622
}
623623

624-
protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDataView> metrics)
624+
private protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDataView> metrics)
625625
{
626626
IDataView top;
627627
if (!metrics.TryGetValue(AnomalyDetectionEvaluator.TopKResults, out top))
@@ -732,7 +732,7 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDa
732732
ch.Info(MetricWriter.GetPerFoldResults(Host, fold, out weightedFold));
733733
}
734734

735-
protected override IDataView GetOverallResultsCore(IDataView overall)
735+
private protected override IDataView GetOverallResultsCore(IDataView overall)
736736
{
737737
return ColumnSelectingTransformer.CreateDrop(Host,
738738
overall,
@@ -742,7 +742,7 @@ protected override IDataView GetOverallResultsCore(IDataView overall)
742742
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos);
743743
}
744744

745-
protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
745+
private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
746746
{
747747
Host.CheckValue(schema, nameof(schema));
748748
Host.CheckValue(schema.Label, nameof(schema), "Data must contain a label column");

src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ public BinaryClassifierEvaluator(IHostEnvironment env, Arguments args)
123123
_auPrcCount = args.NumAuPrcExamples;
124124
}
125125

126-
protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
126+
private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
127127
{
128128
var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
129129
var host = Host.SchemaSensitive();
@@ -136,7 +136,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
136136
throw host.SchemaSensitive().Except("Label column '{0}' has type '{1}' but must be R4, R8, BL or a 2-value key", schema.Label.Name, t);
137137
}
138138

139-
protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema)
139+
private protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema)
140140
{
141141
var prob = schema.GetColumns(MetadataUtils.Const.ScoreValueKind.Probability);
142142
var host = Host.SchemaSensitive();
@@ -155,15 +155,15 @@ protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema)
155155
}
156156

157157
// Add also the probability column.
158-
protected override Func<int, bool> GetActiveColsCore(RoleMappedSchema schema)
158+
private protected override Func<int, bool> GetActiveColsCore(RoleMappedSchema schema)
159159
{
160160
var pred = base.GetActiveColsCore(schema);
161161
var prob = schema.GetColumns(MetadataUtils.Const.ScoreValueKind.Probability);
162162
Host.Assert(prob == null || prob.Count == 1);
163163
return i => Utils.Size(prob) > 0 && i == prob[0].Index || pred(i);
164164
}
165165

166-
protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
166+
private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
167167
{
168168
var classNames = GetClassNames(schema);
169169
return new Aggregator(Host, classNames, schema.Weight != null, _aucCount, _auPrcCount, _threshold, _useRaw, _prCount, stratName);
@@ -215,7 +215,7 @@ public override IEnumerable<MetricColumn> GetOverallMetricColumns()
215215
yield return new MetricColumn("AUPRC", AuPrc);
216216
}
217217

218-
protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
218+
private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
219219
out Action<uint, ReadOnlyMemory<char>, Aggregator> addAgg, out Func<Dictionary<string, IDataView>> consolidate)
220220
{
221221
var stratCol = new List<uint>();
@@ -609,7 +609,7 @@ public Aggregator(IHostEnvironment env, ReadOnlyMemory<char>[] classNames, bool
609609
}
610610
}
611611

612-
public override void InitializeNextPass(Row row, RoleMappedSchema schema)
612+
internal override void InitializeNextPass(Row row, RoleMappedSchema schema)
613613
{
614614
Host.AssertValue(schema.Label);
615615
Host.Assert(PassNum < 1);
@@ -1172,7 +1172,7 @@ public BinaryClassifierMamlEvaluator(IHostEnvironment env, Arguments args)
11721172
_evaluator = new BinaryClassifierEvaluator(Host, evalArgs);
11731173
}
11741174

1175-
protected override IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRolesCore(RoleMappedSchema schema)
1175+
private protected override IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRolesCore(RoleMappedSchema schema)
11761176
{
11771177
var cols = base.GetInputColumnRolesCore(schema);
11781178

@@ -1187,7 +1187,7 @@ public BinaryClassifierMamlEvaluator(IHostEnvironment env, Arguments args)
11871187
return cols;
11881188
}
11891189

1190-
protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDataView> metrics)
1190+
private protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDataView> metrics)
11911191
{
11921192
ch.AssertValue(metrics);
11931193

@@ -1240,12 +1240,12 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDa
12401240
ch.Info(MessageSensitivity.None, unweightedFold);
12411241
}
12421242

1243-
protected override IDataView GetOverallResultsCore(IDataView overall)
1243+
private protected override IDataView GetOverallResultsCore(IDataView overall)
12441244
{
12451245
return ColumnSelectingTransformer.CreateDrop(Host, overall, BinaryClassifierEvaluator.Entropy);
12461246
}
12471247

1248-
protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary<string, IDataView>[] metrics)
1248+
private protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary<string, IDataView>[] metrics)
12491249
{
12501250
ch.AssertNonEmpty(metrics);
12511251

@@ -1479,7 +1479,7 @@ private void SavePrPlots(List<IDataView> prList)
14791479
return avgPoints;
14801480
}
14811481
#endif
1482-
protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
1482+
private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
14831483
{
14841484
Host.CheckValue(schema, nameof(schema));
14851485
Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column");

0 commit comments

Comments
 (0)