diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs index 0014ae638e..2d76cf4caa 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs @@ -25,6 +25,7 @@ public sealed class TransformWrapper : ITransformer private readonly IHost _host; private readonly IDataView _xf; private readonly bool _allowSave; + private readonly bool _isRowToRowMapper; public TransformWrapper(IHostEnvironment env, IDataView xf, bool allowSave = false) { @@ -33,7 +34,7 @@ public TransformWrapper(IHostEnvironment env, IDataView xf, bool allowSave = fal _host.CheckValue(xf, nameof(xf)); _xf = xf; _allowSave = allowSave; - IsRowToRowMapper = IsChainRowToRowMapper(_xf); + _isRowToRowMapper = IsChainRowToRowMapper(_xf); } public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) @@ -108,7 +109,7 @@ private TransformWrapper(IHostEnvironment env, ModelLoadContext ctx) } _xf = data; - IsRowToRowMapper = IsChainRowToRowMapper(_xf); + _isRowToRowMapper = IsChainRowToRowMapper(_xf); } public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input); @@ -123,9 +124,9 @@ private static bool IsChainRowToRowMapper(IDataView view) return true; } - public bool IsRowToRowMapper { get; } + bool ITransformer.IsRowToRowMapper => _isRowToRowMapper; - public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema) + IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); var input = new EmptyDataView(_host, inputSchema); diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs index 30202c0617..d776b10f19 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs @@ -59,7 +59,7 @@ public sealed class TransformerChain : ITransformer, IEnumerab private const string TransformDirTemplate = "Transform_{0:000}"; - public bool IsRowToRowMapper => _transformers.All(t => t.IsRowToRowMapper); + bool ITransformer.IsRowToRowMapper => _transformers.All(t => t.IsRowToRowMapper); ITransformer[] ITransformerChainAccessor.Transformers => _transformers; @@ -216,10 +216,11 @@ public void SaveTo(IHostEnvironment env, Stream outputStream) IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema) + IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) { Contracts.CheckValue(inputSchema, nameof(inputSchema)); - Contracts.Check(IsRowToRowMapper, nameof(GetRowToRowMapper) + " method called despite " + nameof(IsRowToRowMapper) + " being false."); + Contracts.Check(((ITransformer)this).IsRowToRowMapper, nameof(ITransformer.GetRowToRowMapper) + " method called despite " + + nameof(ITransformer.IsRowToRowMapper) + " being false."); IRowToRowMapper[] mappers = new IRowToRowMapper[_transformers.Length]; DataViewSchema schema = inputSchema; diff --git a/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs b/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs index 6b396180c3..2100ef3fd0 100644 --- a/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs +++ b/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs @@ -116,8 +116,8 @@ private static Func StreamChecker(IHostEnvironm { var pipe = DataViewConstructionUtils.LoadPipeWithPredictor(env, modelStream, new EmptyDataView(env, schema)); var transformer = new TransformWrapper(env, pipe); - env.CheckParam(transformer.IsRowToRowMapper, nameof(transformer), "Must be a row to row mapper"); - return transformer.GetRowToRowMapper(schema); + env.CheckParam(((ITransformer)transformer).IsRowToRowMapper, nameof(transformer), "Must be a row to row mapper"); + return ((ITransformer)transformer).GetRowToRowMapper(schema); }; } diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index 8bab13fa5d..7b94c92795 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -55,10 +55,10 @@ public abstract class PredictionTransformerBase : IPredictionTransformer protected DataViewSchema TrainSchema; /// - /// Whether a call to should succeed, on an + /// Whether a call to should succeed, on an /// appropriate schema. /// - public bool IsRowToRowMapper => true; + bool ITransformer.IsRowToRowMapper => true; /// /// This class is more or less a thin wrapper over the implementing @@ -132,7 +132,7 @@ public IDataView Transform(IDataView input) /// /// /// - public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema) + IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); return (IRowToRowMapper)Scorer.ApplyToData(Host, new EmptyDataView(Host, inputSchema)); diff --git a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs index 2a0c0686d2..2e8354e2f1 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs @@ -138,7 +138,7 @@ public sealed class ColumnSelectingTransformer : ITransformer private readonly IHost _host; private string[] _selectedColumns; - public bool IsRowToRowMapper => true; + bool ITransformer.IsRowToRowMapper => true; public IEnumerable SelectColumns => _selectedColumns.AsReadOnly(); @@ -458,13 +458,13 @@ public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) } /// - /// Constructs a row-to-row mapper based on an input schema. If + /// Constructs a row-to-row mapper based on an input schema. If /// is false, then an exception is thrown. If the input schema is in any way /// unsuitable for constructing the mapper, an exception should likewise be thrown. /// /// The input schema for which we should get the mapper. /// The row to row mapper. - public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema) + IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); if (!IgnoreMissing && !IsSchemaValid(inputSchema.Select(x => x.Name), diff --git a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs index e3166d7407..78a34ad8bf 100644 --- a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs +++ b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs @@ -26,9 +26,9 @@ protected RowToRowTransformerBase(IHost host) private protected abstract void SaveModel(ModelSaveContext ctx); - public bool IsRowToRowMapper => true; + bool ITransformer.IsRowToRowMapper => true; - public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema) + IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); return new RowToRowMapperTransform(Host, new EmptyDataView(Host, inputSchema), MakeRowMapper(inputSchema), MakeRowMapper); diff --git a/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs b/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs index 38cca21a49..46902ae369 100644 --- a/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs +++ b/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs @@ -18,10 +18,10 @@ namespace Microsoft.ML.Transforms.TimeSeries public class IidAnomalyDetectionBaseWrapper : IStatefulTransformer, ICanSaveModel { /// - /// Whether a call to should succeed, on an + /// Whether a call to should succeed, on an /// appropriate schema. /// - public bool IsRowToRowMapper => InternalTransform.IsRowToRowMapper; + bool ITransformer.IsRowToRowMapper => ((ITransformer)InternalTransform).IsRowToRowMapper; /// /// Creates a clone of the transfomer. Used for taking the snapshot of the state. @@ -36,20 +36,22 @@ public class IidAnomalyDetectionBaseWrapper : IStatefulTransformer, ICanSaveMode public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) => InternalTransform.GetOutputSchema(inputSchema); /// - /// Constructs a row-to-row mapper based on an input schema. If + /// Constructs a row-to-row mapper based on an input schema. If /// is false, then an exception should be thrown. If the input schema is in any way /// unsuitable for constructing the mapper, an exception should likewise be thrown. /// /// The input schema for which we should get the mapper. /// The row to row mapper. - public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema) => InternalTransform.GetRowToRowMapper(inputSchema); + IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) + => ((ITransformer)InternalTransform).GetRowToRowMapper(inputSchema); /// /// Same as but also supports mechanism to save the state. /// /// The input schema for which we should get the mapper. /// The row to row mapper. - public IRowToRowMapper GetStatefulRowToRowMapper(DataViewSchema inputSchema) => ((IStatefulTransformer)InternalTransform).GetStatefulRowToRowMapper(inputSchema); + public IRowToRowMapper GetStatefulRowToRowMapper(DataViewSchema inputSchema) + => ((IStatefulTransformer)InternalTransform).GetStatefulRowToRowMapper(inputSchema); /// /// Take the data in, make transformations, output the data. @@ -60,7 +62,9 @@ public class IidAnomalyDetectionBaseWrapper : IStatefulTransformer, ICanSaveMode /// /// For saving a model into a repository. /// - public virtual void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected virtual void SaveModel(ModelSaveContext ctx) { InternalTransform.SaveThis(ctx); } @@ -129,7 +133,7 @@ public override DataViewSchema GetOutputSchema(DataViewSchema inputSchema) private protected override void SaveModel(ModelSaveContext ctx) { - Parent.Save(ctx); + ((ICanSaveModel)Parent).Save(ctx); } internal void SaveThis(ModelSaveContext ctx) diff --git a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs index da8749e18e..2b33c640b7 100644 --- a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs +++ b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs @@ -172,7 +172,7 @@ private IidChangePointDetector(IHostEnvironment env, IidChangePointDetector tran { } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { InternalTransform.Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -184,7 +184,7 @@ public override void Save(ModelSaveContext ctx) // *** Binary format *** // - base.Save(ctx); + base.SaveModel(ctx); } // Factory method for SignatureLoadRowMapper. diff --git a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs index ce1091ed53..e05e990bc6 100644 --- a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs +++ b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs @@ -153,7 +153,7 @@ private IidSpikeDetector(IHostEnvironment env, IidSpikeDetector transform) { } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { InternalTransform.Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -164,7 +164,7 @@ public override void Save(ModelSaveContext ctx) // *** Binary format *** // - base.Save(ctx); + base.SaveModel(ctx); } // Factory method for SignatureLoadRowMapper. diff --git a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs index 145235a1cb..15c9828c5d 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs @@ -348,7 +348,7 @@ public Func GetDependencies(Func activeOutput) return col => false; } - public void Save(ModelSaveContext ctx) => _parent.SaveModel(ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => _parent.SaveModel(ctx); public Delegate[] CreateGetters(DataViewRow input, Func activeOutput, out Action disposer) { diff --git a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs index 80824a6895..0626216900 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs @@ -269,7 +269,7 @@ private protected virtual void CloneCore(TState state) internal readonly string OutputColumnName; private protected DataViewType OutputColumnType; - public bool IsRowToRowMapper => false; + bool ITransformer.IsRowToRowMapper => false; internal TState StateRef { get; set; } diff --git a/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs b/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs index 8241602399..704002298d 100644 --- a/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs +++ b/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs @@ -87,10 +87,10 @@ public static Func GetErrorFunction(ErrorFunction errorF public class SsaAnomalyDetectionBaseWrapper : IStatefulTransformer, ICanSaveModel { /// - /// Whether a call to should succeed, on an + /// Whether a call to should succeed, on an /// appropriate schema. /// - public bool IsRowToRowMapper => InternalTransform.IsRowToRowMapper; + bool ITransformer.IsRowToRowMapper => ((ITransformer)InternalTransform).IsRowToRowMapper; /// /// Creates a clone of the transfomer. Used for taking the snapshot of the state. @@ -105,20 +105,22 @@ public class SsaAnomalyDetectionBaseWrapper : IStatefulTransformer, ICanSaveMode public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) => InternalTransform.GetOutputSchema(inputSchema); /// - /// Constructs a row-to-row mapper based on an input schema. If + /// Constructs a row-to-row mapper based on an input schema. If /// is false, then an exception should be thrown. If the input schema is in any way /// unsuitable for constructing the mapper, an exception should likewise be thrown. /// /// The input schema for which we should get the mapper. /// The row to row mapper. - public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema) => InternalTransform.GetRowToRowMapper(inputSchema); + IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) + => ((ITransformer)InternalTransform).GetRowToRowMapper(inputSchema); /// /// Same as but also supports mechanism to save the state. /// /// The input schema for which we should get the mapper. /// The row to row mapper. - public IRowToRowMapper GetStatefulRowToRowMapper(DataViewSchema inputSchema) => ((IStatefulTransformer)InternalTransform).GetStatefulRowToRowMapper(inputSchema); + public IRowToRowMapper GetStatefulRowToRowMapper(DataViewSchema inputSchema) + => ((IStatefulTransformer)InternalTransform).GetStatefulRowToRowMapper(inputSchema); /// /// Take the data in, make transformations, output the data. @@ -129,7 +131,9 @@ public class SsaAnomalyDetectionBaseWrapper : IStatefulTransformer, ICanSaveMode /// /// For saving a model into a repository. /// - public virtual void Save(ModelSaveContext ctx) => InternalTransform.SaveThis(ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected virtual void SaveModel(ModelSaveContext ctx) => InternalTransform.SaveThis(ctx); /// /// Creates a row mapper from Schema. @@ -255,7 +259,7 @@ public override DataViewSchema GetOutputSchema(DataViewSchema inputSchema) private protected override void SaveModel(ModelSaveContext ctx) { - Parent.Save(ctx); + ((ICanSaveModel)Parent).Save(ctx); } internal void SaveThis(ModelSaveContext ctx) diff --git a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs index 343431db74..ac7dee6bc1 100644 --- a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs +++ b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs @@ -180,7 +180,7 @@ internal SsaChangePointDetector(IHostEnvironment env, ModelLoadContext ctx) InternalTransform.Host.CheckDecode(InternalTransform.IsAdaptive == false); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { InternalTransform.Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -194,7 +194,7 @@ public override void Save(ModelSaveContext ctx) // *** Binary format *** // - base.Save(ctx); + base.SaveModel(ctx); } // Factory method for SignatureLoadRowMapper. diff --git a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs index a35afcb94d..1b84019385 100644 --- a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs +++ b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs @@ -162,7 +162,7 @@ internal SsaSpikeDetector(IHostEnvironment env, ModelLoadContext ctx) InternalTransform.Host.CheckDecode(InternalTransform.IsAdaptive == false); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { InternalTransform.Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -175,7 +175,7 @@ public override void Save(ModelSaveContext ctx) // *** Binary format *** // - base.Save(ctx); + base.SaveModel(ctx); } // Factory method for SignatureLoadRowMapper. diff --git a/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs b/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs index b641ebf3e8..2ab2320334 100644 --- a/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs +++ b/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs @@ -30,10 +30,10 @@ public sealed class CustomMappingTransformer : ITransformer internal SchemaDefinition InputSchemaDefinition { get; } /// - /// Whether a call to should succeed, on an + /// Whether a call to should succeed, on an /// appropriate schema. /// - public bool IsRowToRowMapper => true; + bool ITransformer.IsRowToRowMapper => true; /// /// Create a custom mapping of input columns to output columns. @@ -95,11 +95,11 @@ public IDataView Transform(IDataView input) } /// - /// Constructs a row-to-row mapper based on an input schema. If + /// Constructs a row-to-row mapper based on an input schema. If /// is false, then an exception is thrown. If the is in any way /// unsuitable for constructing the mapper, an exception is likewise thrown. /// - public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema) + IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); var simplerMapper = MakeRowMapper(inputSchema); diff --git a/src/Microsoft.ML.Transforms/OneHotEncoding.cs b/src/Microsoft.ML.Transforms/OneHotEncoding.cs index 17b54c27e6..19559d5b33 100644 --- a/src/Microsoft.ML.Transforms/OneHotEncoding.cs +++ b/src/Microsoft.ML.Transforms/OneHotEncoding.cs @@ -166,9 +166,9 @@ internal OneHotEncodingTransformer(ValueToKeyMappingEstimator term, IEstimator (_transformer as ICanSaveModel).Save(ctx); - public bool IsRowToRowMapper => _transformer.IsRowToRowMapper; + bool ITransformer.IsRowToRowMapper => ((ITransformer)_transformer).IsRowToRowMapper; - public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema) => _transformer.GetRowToRowMapper(inputSchema); + IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) => ((ITransformer)_transformer).GetRowToRowMapper(inputSchema); } /// /// Estimator which takes set of columns and produce for each column indicator array. diff --git a/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs b/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs index ee6e7a4ceb..5f2118f388 100644 --- a/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs +++ b/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs @@ -191,14 +191,14 @@ internal OneHotHashEncodingTransformer(HashingEstimator hash, IEstimator (_transformer as ICanSaveModel).Save(ctx); /// - /// Whether a call to should succeed, on an appropriate schema. + /// Whether a call to should succeed, on an appropriate schema. /// - public bool IsRowToRowMapper => _transformer.IsRowToRowMapper; + bool ITransformer.IsRowToRowMapper => ((ITransformer)_transformer).IsRowToRowMapper; /// /// Constructs a row-to-row mapper based on an input schema. /// - public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema) => _transformer.GetRowToRowMapper(inputSchema); + IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) => ((ITransformer)_transformer).GetRowToRowMapper(inputSchema); } /// diff --git a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs index 3f91e73ebf..311a0b5f3a 100644 --- a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs +++ b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs @@ -586,9 +586,9 @@ public IDataView Transform(IDataView input) return ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input); } - public bool IsRowToRowMapper => true; + bool ITransformer.IsRowToRowMapper => true; - public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema) + IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); var input = new EmptyDataView(_host, inputSchema); diff --git a/test/Microsoft.ML.Benchmarks/HashBench.cs b/test/Microsoft.ML.Benchmarks/HashBench.cs index f84e32c92d..68350a901a 100644 --- a/test/Microsoft.ML.Benchmarks/HashBench.cs +++ b/test/Microsoft.ML.Benchmarks/HashBench.cs @@ -75,7 +75,7 @@ private void InitMap(T val, DataViewType type, int hashBits = 20, ValueGetter // One million features is a nice, typical number. var info = new HashingEstimator.ColumnInfo("Bar", "Foo", hashBits: hashBits); var xf = new HashingTransformer(_env, new[] { info }); - var mapper = xf.GetRowToRowMapper(_inRow.Schema); + var mapper = ((ITransformer)xf).GetRowToRowMapper(_inRow.Schema); var column = mapper.OutputSchema["Bar"]; var outRow = mapper.GetRow(_inRow, c => c == column.Index); if (type is VectorType) diff --git a/test/Microsoft.ML.Tests/Transformers/HashTests.cs b/test/Microsoft.ML.Tests/Transformers/HashTests.cs index f075742cb2..d309d8b213 100644 --- a/test/Microsoft.ML.Tests/Transformers/HashTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/HashTests.cs @@ -135,7 +135,7 @@ private void HashTestCore(T val, PrimitiveDataViewType type, uint expected, u // First do an unordered hash. var info = new HashingEstimator.ColumnInfo("Bar", "Foo", hashBits: bits); var xf = new HashingTransformer(Env, new[] { info }); - var mapper = xf.GetRowToRowMapper(inRow.Schema); + var mapper = ((ITransformer)xf).GetRowToRowMapper(inRow.Schema); mapper.OutputSchema.TryGetColumnIndex("Bar", out int outCol); var outRow = mapper.GetRow(inRow, c => c == outCol); @@ -147,7 +147,7 @@ private void HashTestCore(T val, PrimitiveDataViewType type, uint expected, u // Next do an ordered hash. info = new HashingEstimator.ColumnInfo("Bar", "Foo", hashBits: bits, ordered: true); xf = new HashingTransformer(Env, new[] { info }); - mapper = xf.GetRowToRowMapper(inRow.Schema); + mapper = ((ITransformer)xf).GetRowToRowMapper(inRow.Schema); mapper.OutputSchema.TryGetColumnIndex("Bar", out outCol); outRow = mapper.GetRow(inRow, c => c == outCol); @@ -165,7 +165,7 @@ private void HashTestCore(T val, PrimitiveDataViewType type, uint expected, u info = new HashingEstimator.ColumnInfo("Bar", "Foo", hashBits: bits, ordered: false); xf = new HashingTransformer(Env, new[] { info }); - mapper = xf.GetRowToRowMapper(inRow.Schema); + mapper = ((ITransformer)xf).GetRowToRowMapper(inRow.Schema); mapper.OutputSchema.TryGetColumnIndex("Bar", out outCol); outRow = mapper.GetRow(inRow, c => c == outCol); @@ -180,7 +180,7 @@ private void HashTestCore(T val, PrimitiveDataViewType type, uint expected, u // Now do ordered with the dense vector. info = new HashingEstimator.ColumnInfo("Bar", "Foo", hashBits: bits, ordered: true); xf = new HashingTransformer(Env, new[] { info }); - mapper = xf.GetRowToRowMapper(inRow.Schema); + mapper = ((ITransformer)xf).GetRowToRowMapper(inRow.Schema); mapper.OutputSchema.TryGetColumnIndex("Bar", out outCol); outRow = mapper.GetRow(inRow, c => c == outCol); vecGetter = outRow.GetGetter>(outCol); @@ -199,7 +199,7 @@ private void HashTestCore(T val, PrimitiveDataViewType type, uint expected, u info = new HashingEstimator.ColumnInfo("Bar", "Foo", hashBits: bits, ordered: false); xf = new HashingTransformer(Env, new[] { info }); - mapper = xf.GetRowToRowMapper(inRow.Schema); + mapper = ((ITransformer)xf).GetRowToRowMapper(inRow.Schema); mapper.OutputSchema.TryGetColumnIndex("Bar", out outCol); outRow = mapper.GetRow(inRow, c => c == outCol); vecGetter = outRow.GetGetter>(outCol); @@ -212,7 +212,7 @@ private void HashTestCore(T val, PrimitiveDataViewType type, uint expected, u info = new HashingEstimator.ColumnInfo("Bar", "Foo", hashBits: bits, ordered: true); xf = new HashingTransformer(Env, new[] { info }); - mapper = xf.GetRowToRowMapper(inRow.Schema); + mapper = ((ITransformer)xf).GetRowToRowMapper(inRow.Schema); mapper.OutputSchema.TryGetColumnIndex("Bar", out outCol); outRow = mapper.GetRow(inRow, c => c == outCol); vecGetter = outRow.GetGetter>(outCol);