Skip to content

Internalize RoleMappedSchema and implications thereof #1902

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions src/Microsoft.ML.Core/Data/IRowToRowMapper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Data;
using System;

namespace Microsoft.ML.Runtime.Data
{
/// <summary>
/// This interface maps an input <see cref="Row"/> to an output <see cref="Row"/>. Typically, the output contains
/// both the input columns and new columns added by the implementing class, although some implementations may
/// return a subset of the input columns.
/// This interface is similar to <see cref="ISchemaBoundRowMapper"/>, except it does not have any input role mappings,
/// so to rebind, the same input column names must be used.
/// Implementations of this interface are typically created over defined input <see cref="Schema"/>.
/// </summary>
public interface IRowToRowMapper
{
/// <summary>
/// Gets the input <see cref="Schema"/>. Mappers are defined as accepting inputs with this very specific schema.
/// </summary>
Schema InputSchema { get; }

/// <summary>
/// Gets an instance of <see cref="Schema"/> which describes the columns' names and types in the output generated by this mapper.
/// </summary>
Schema OutputSchema { get; }

/// <summary>
/// Given a predicate specifying which columns are needed, return a predicate indicating which input columns are
/// needed. The domain of the function is defined over the indices of the columns of <see cref="Schema.Count"/>
/// for <see cref="InputSchema"/>.
/// </summary>
/// <param name="predicate">A predicate specifying which columns are needed.</param>
/// <returns>A predicate indicating which columns of <see cref="InputSchema"/> are needed.</returns>
Func<int, bool> GetDependencies(Func<int, bool> predicate);

/// <summary>
/// Gets a <see cref="Row"/> with the indicated active columns, based on the input <paramref name="input"/>.
/// The active columns are those for which <paramref name="active"/> returns true. Getting values on inactive
/// columns of the returned row will throw. Null predicates are disallowed.
/// <para>
/// The <see cref="Row.Schema"/> of <paramref name="input"/> should be the same object as
/// <see cref="InputSchema"/>. Implementors of this method should throw if that is not the case. Conversely,
/// the returned value must have the same schema as <see cref="OutputSchema"/>.
/// </para>
/// <para>
/// This method creates a live connection between the input <see cref="Row"/> and the output <see
/// cref="Row"/>. In particular, when the getters of the output <see cref="Row"/> are invoked, they invoke the
/// getters of the input row and base the output values on the current values of the input <see cref="Row"/>.
/// The output <see cref="Row"/> values are re-computed when requested through the getters. Also, the returned
/// <see cref="Row"/> will dispose <paramref name="input"/> when it is disposed.
/// </para>
/// </summary>
/// <param name="input">The input row, whose <see cref="Row.Schema"/> should be the same object as
/// <see cref="InputSchema"/>. It is disposed when the returned row is disposed.</param>
/// <param name="active">A non-null predicate that returns true for the columns that should be active
/// in the returned row.</param>
/// <returns>A <see cref="Row"/> having the same schema as <see cref="OutputSchema"/>, whose getters read
/// from <paramref name="input"/> when invoked.</returns>
Row GetRow(Row input, Func<int, bool> active);
}
}
55 changes: 6 additions & 49 deletions src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Data;
using System;
using System.Collections.Generic;

namespace Microsoft.ML.Runtime.Data
Expand All @@ -21,7 +20,8 @@ namespace Microsoft.ML.Runtime.Data
/// for the output schema of the <see cref="ISchemaBoundMapper"/>. In case the <see cref="ISchemaBoundRowMapper"/> interface is implemented,
/// the SimpleRow class can be used in the <see cref="IRowToRowMapper.GetRow"/> method.
/// </summary>
public interface ISchemaBindableMapper
[BestFriend]
internal interface ISchemaBindableMapper
{
ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema);
}
Expand All @@ -30,7 +30,8 @@ public interface ISchemaBindableMapper
/// This interface is used to map a schema from input columns to output columns. The <see cref="ISchemaBoundMapper"/> should keep track
/// of the input columns that are needed for the mapping.
/// </summary>
public interface ISchemaBoundMapper
[BestFriend]
internal interface ISchemaBoundMapper
{
/// <summary>
/// The <see cref="RoleMappedSchema"/> that was passed to the <see cref="ISchemaBoundMapper"/> in the binding process.
Expand All @@ -56,57 +57,13 @@ public interface ISchemaBoundMapper
/// <summary>
/// This interface combines <see cref="ISchemaBoundMapper"/> with <see cref="IRowToRowMapper"/>.
/// </summary>
public interface ISchemaBoundRowMapper : ISchemaBoundMapper, IRowToRowMapper
[BestFriend]
internal interface ISchemaBoundRowMapper : ISchemaBoundMapper, IRowToRowMapper
{
/// <summary>
/// There are two schemas from <see cref="ISchemaBoundMapper"/> and <see cref="IRowToRowMapper"/>.
/// Since the two parent schemas are identical in all derived classes, we merge them into <see cref="OutputSchema"/>.
/// </summary>
new Schema OutputSchema { get; }
}

/// <summary>
/// This interface maps an input <see cref="Row"/> to an output <see cref="Row"/>. Typically, the output contains
/// both the input columns and new columns added by the implementing class, although some implementations may
/// return a subset of the input columns.
/// This interface is similar to <see cref="ISchemaBoundRowMapper"/>, except it does not have any input role mappings,
/// so to rebind, the same input column names must be used.
/// Implementations of this interface are typically created over defined input <see cref="Schema"/>.
/// </summary>
public interface IRowToRowMapper
{
/// <summary>
/// Mappers are defined as accepting inputs with this very specific schema.
/// </summary>
Schema InputSchema { get; }

/// <summary>
/// Gets an instance of <see cref="Schema"/> which describes the columns' names and types in the output generated by this mapper.
/// </summary>
Schema OutputSchema { get; }

/// <summary>
/// Given a predicate specifying which columns are needed, return a predicate indicating which input columns are
/// needed. The domain of the function is defined over the indices of the columns of <see cref="Schema.Count"/>
/// for <see cref="InputSchema"/>.
/// </summary>
Func<int, bool> GetDependencies(Func<int, bool> predicate);

/// <summary>
/// Get an <see cref="Row"/> with the indicated active columns, based on the input <paramref name="input"/>.
/// The active columns are those for which <paramref name="active"/> returns true. Getting values on inactive
/// columns of the returned row will throw. Null predicates are disallowed.
///
/// The <see cref="Row.Schema"/> of <paramref name="input"/> should be the same object as
/// <see cref="InputSchema"/>. Implementors of this method should throw if that is not the case. Conversely,
/// the returned value must have the same schema as <see cref="OutputSchema"/>.
///
/// This method creates a live connection between the input <see cref="Row"/> and the output <see
/// cref="Row"/>. In particular, when the getters of the output <see cref="Row"/> are invoked, they invoke the
/// getters of the input row and base the output values on the current values of the input <see cref="Row"/>.
/// The output <see cref="Row"/> values are re-computed when requested through the getters. Also, the returned
/// <see cref="Row"/> will dispose <paramref name="input"/> when it is disposed.
/// </summary>
Row GetRow(Row input, Func<int, bool> active);
}
}
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Core/Data/RoleMappedSchema.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ public static ColumnInfo CreateFromIndex(Schema schema, int index)
/// </remarks>
/// <seealso cref="ColumnRole"/>
/// <seealso cref="RoleMappedData"/>
public sealed class RoleMappedSchema
[BestFriend]
internal sealed class RoleMappedSchema
{
private const string FeatureString = "Feature";
private const string LabelString = "Label";
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Data/Commands/SavePredictorCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ private Stream CreateStrm(IFileHandle file)
}
}

public static class SavePredictorUtils
[BestFriend]
internal static class SavePredictorUtils
{
public static void SavePredictor(IHostEnvironment env, Stream modelStream, Stream binaryModelStream = null, Stream summaryModelStream = null,
Stream textModelStream = null, Stream iniModelStream = null, Stream codeModelStream = null)
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Data/Commands/ScoreCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ public interface IDataScorerTransform : IDataTransform, ITransformTemplate
/// <param name="trainSchema">This parameter holds a snapshot of the role mapped training schema as
/// it existed at the point when <paramref name="mapper"/> was trained, or <c>null</c> if it is not
/// available for some reason</param>
public delegate void SignatureDataScorer(IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema);
[BestFriend]
internal delegate void SignatureDataScorer(IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema);

public delegate void SignatureBindableMapper(IPredictor predictor);

Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

namespace Microsoft.ML.Runtime.Internal.Internallearn
{
public abstract class FeatureNameCollection : IEnumerable<string>
[BestFriend]
internal abstract class FeatureNameCollection : IEnumerable<string>
{
private sealed class FeatureNameCollectionSchema : ISchema
{
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ internal interface ICanGetSummaryAsIRow
Row GetStatsIRowOrNull(RoleMappedSchema schema);
}

public interface ICanGetSummaryAsIDataView
[BestFriend]
internal interface ICanGetSummaryAsIDataView
{
IDataView GetSummaryDataView(RoleMappedSchema schema);
}
Expand Down
5 changes: 3 additions & 2 deletions src/Microsoft.ML.Data/Dirty/PredictorUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@

namespace Microsoft.ML.Runtime.Internal.Internallearn
{
public static class PredictorUtils
[BestFriend]
internal static class PredictorUtils
{
/// <summary>
/// Save the model summary
/// Save the model summary.
/// </summary>
public static void SaveSummary(IChannel ch, IPredictor predictor, RoleMappedSchema schema, TextWriter writer)
{
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ public static CommonOutputs.SummaryOutput Summarize(IHostEnvironment env, Summar
return output;
}

public static IDataView GetSummaryAndStats(IHostEnvironment env, IPredictor predictor, RoleMappedSchema schema, out IDataView stats)
[BestFriend]
internal static IDataView GetSummaryAndStats(IHostEnvironment env, IPredictor predictor, RoleMappedSchema schema, out IDataView stats)
{
var calibrated = predictor as CalibratedPredictorBase;
while (calibrated != null)
Expand Down
14 changes: 7 additions & 7 deletions src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public AnomalyDetectionEvaluator(IHostEnvironment env, Arguments args)
_aucCount = args.MaxAucExamples;
}

protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
{
var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
var t = score.Type;
Expand All @@ -103,7 +103,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
throw Host.Except("Label column '{0}' has type '{1}' but must be R4 or a 2-value key", schema.Label.Name, t).MarkSensitive(MessageSensitivity.Schema);
}

protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
{
return new Aggregator(Host, _aucCount, _numTopResults, _k, _p, _streaming, schema.Name == null ? -1 : schema.Name.Index, stratName);
}
Expand All @@ -124,7 +124,7 @@ public override IEnumerable<MetricColumn> GetOverallMetricColumns()
yield return new MetricColumn("ThreshAtNumPos", OverallMetrics.ThreshAtNumPos, MetricColumn.Objective.Info, canBeWeighted: false);
}

protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
out Action<uint, ReadOnlyMemory<char>, Aggregator> addAgg, out Func<Dictionary<string, IDataView>> consolidate)
{
var stratCol = new List<uint>();
Expand Down Expand Up @@ -498,7 +498,7 @@ private void FinishOtherMetrics()
}
}

public override void InitializeNextPass(Row row, RoleMappedSchema schema)
internal override void InitializeNextPass(Row row, RoleMappedSchema schema)
{
Host.Assert(!_streaming && PassNum < 2 || PassNum < 1);
Host.AssertValue(schema.Label);
Expand Down Expand Up @@ -621,7 +621,7 @@ public AnomalyDetectionMamlEvaluator(IHostEnvironment env, Arguments args)
_evaluator = new AnomalyDetectionEvaluator(Host, evalArgs);
}

protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDataView> metrics)
private protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDataView> metrics)
{
IDataView top;
if (!metrics.TryGetValue(AnomalyDetectionEvaluator.TopKResults, out top))
Expand Down Expand Up @@ -732,7 +732,7 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDa
ch.Info(MetricWriter.GetPerFoldResults(Host, fold, out weightedFold));
}

protected override IDataView GetOverallResultsCore(IDataView overall)
private protected override IDataView GetOverallResultsCore(IDataView overall)
{
return ColumnSelectingTransformer.CreateDrop(Host,
overall,
Expand All @@ -742,7 +742,7 @@ protected override IDataView GetOverallResultsCore(IDataView overall)
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos);
}

protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
{
Host.CheckValue(schema, nameof(schema));
Host.CheckValue(schema.Label, nameof(schema), "Data must contain a label column");
Expand Down
22 changes: 11 additions & 11 deletions src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ public BinaryClassifierEvaluator(IHostEnvironment env, Arguments args)
_auPrcCount = args.NumAuPrcExamples;
}

protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
{
var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
var host = Host.SchemaSensitive();
Expand All @@ -136,7 +136,7 @@ protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
throw host.SchemaSensitive().Except("Label column '{0}' has type '{1}' but must be R4, R8, BL or a 2-value key", schema.Label.Name, t);
}

protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema)
private protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema)
{
var prob = schema.GetColumns(MetadataUtils.Const.ScoreValueKind.Probability);
var host = Host.SchemaSensitive();
Expand All @@ -155,15 +155,15 @@ protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema)
}

// Add also the probability column.
protected override Func<int, bool> GetActiveColsCore(RoleMappedSchema schema)
private protected override Func<int, bool> GetActiveColsCore(RoleMappedSchema schema)
{
var pred = base.GetActiveColsCore(schema);
var prob = schema.GetColumns(MetadataUtils.Const.ScoreValueKind.Probability);
Host.Assert(prob == null || prob.Count == 1);
return i => Utils.Size(prob) > 0 && i == prob[0].Index || pred(i);
}

protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
{
var classNames = GetClassNames(schema);
return new Aggregator(Host, classNames, schema.Weight != null, _aucCount, _auPrcCount, _threshold, _useRaw, _prCount, stratName);
Expand Down Expand Up @@ -215,7 +215,7 @@ public override IEnumerable<MetricColumn> GetOverallMetricColumns()
yield return new MetricColumn("AUPRC", AuPrc);
}

protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
out Action<uint, ReadOnlyMemory<char>, Aggregator> addAgg, out Func<Dictionary<string, IDataView>> consolidate)
{
var stratCol = new List<uint>();
Expand Down Expand Up @@ -609,7 +609,7 @@ public Aggregator(IHostEnvironment env, ReadOnlyMemory<char>[] classNames, bool
}
}

public override void InitializeNextPass(Row row, RoleMappedSchema schema)
internal override void InitializeNextPass(Row row, RoleMappedSchema schema)
{
Host.AssertValue(schema.Label);
Host.Assert(PassNum < 1);
Expand Down Expand Up @@ -1172,7 +1172,7 @@ public BinaryClassifierMamlEvaluator(IHostEnvironment env, Arguments args)
_evaluator = new BinaryClassifierEvaluator(Host, evalArgs);
}

protected override IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRolesCore(RoleMappedSchema schema)
private protected override IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRolesCore(RoleMappedSchema schema)
{
var cols = base.GetInputColumnRolesCore(schema);

Expand All @@ -1187,7 +1187,7 @@ public BinaryClassifierMamlEvaluator(IHostEnvironment env, Arguments args)
return cols;
}

protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDataView> metrics)
private protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDataView> metrics)
{
ch.AssertValue(metrics);

Expand Down Expand Up @@ -1240,12 +1240,12 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDa
ch.Info(MessageSensitivity.None, unweightedFold);
}

protected override IDataView GetOverallResultsCore(IDataView overall)
private protected override IDataView GetOverallResultsCore(IDataView overall)
{
return ColumnSelectingTransformer.CreateDrop(Host, overall, BinaryClassifierEvaluator.Entropy);
}

protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary<string, IDataView>[] metrics)
private protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary<string, IDataView>[] metrics)
{
ch.AssertNonEmpty(metrics);

Expand Down Expand Up @@ -1479,7 +1479,7 @@ private void SavePrPlots(List<IDataView> prList)
return avgPoints;
}
#endif
protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
{
Host.CheckValue(schema, nameof(schema));
Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column");
Expand Down
Loading