diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs index 6ebef55827..1e96439a1e 100644 --- a/src/Microsoft.ML.Core/Data/IEstimator.cs +++ b/src/Microsoft.ML.Core/Data/IEstimator.cs @@ -7,6 +7,7 @@ using System.Linq; using Microsoft.Data.DataView; using Microsoft.ML.Data; +using Microsoft.ML.Model; namespace Microsoft.ML.Core.Data { @@ -263,7 +264,7 @@ public interface IDataReaderEstimator /// The transformer is a component that transforms data. /// It also supports 'schema propagation' to answer the question of 'how will the data with this schema look, after you transform it?'. /// - public interface ITransformer + public interface ITransformer : ICanSaveModel { /// /// Schema propagation for transformers. diff --git a/src/Microsoft.ML.Data/Model/ModelHeader.cs b/src/Microsoft.ML.Core/Data/ModelHeader.cs similarity index 99% rename from src/Microsoft.ML.Data/Model/ModelHeader.cs rename to src/Microsoft.ML.Core/Data/ModelHeader.cs index 88fcf9bdad..d1b572e8bc 100644 --- a/src/Microsoft.ML.Data/Model/ModelHeader.cs +++ b/src/Microsoft.ML.Core/Data/ModelHeader.cs @@ -10,7 +10,8 @@ namespace Microsoft.ML.Model { - [StructLayout(LayoutKind.Explicit, Size = ModelHeader.Size)] + [BestFriend] + [StructLayout(LayoutKind.Explicit, Size = Size)] internal struct ModelHeader { /// diff --git a/src/Microsoft.ML.Data/Model/ModelLoadContext.cs b/src/Microsoft.ML.Core/Data/ModelLoadContext.cs similarity index 100% rename from src/Microsoft.ML.Data/Model/ModelLoadContext.cs rename to src/Microsoft.ML.Core/Data/ModelLoadContext.cs diff --git a/src/Microsoft.ML.Data/Model/ModelLoading.cs b/src/Microsoft.ML.Core/Data/ModelLoading.cs similarity index 98% rename from src/Microsoft.ML.Data/Model/ModelLoading.cs rename to src/Microsoft.ML.Core/Data/ModelLoading.cs index 06ebc0bf06..d561389e39 100644 --- a/src/Microsoft.ML.Data/Model/ModelLoading.cs +++ b/src/Microsoft.ML.Core/Data/ModelLoading.cs @@ -10,6 +10,12 @@ namespace Microsoft.ML.Model { + /// + /// Signature for a repository based model loader. This is the dual of . + /// + [BestFriend] + internal delegate void SignatureLoadModel(ModelLoadContext ctx); + public sealed partial class ModelLoadContext : IDisposable { public const string ModelStreamName = "Model.key"; diff --git a/src/Microsoft.ML.Data/Model/ModelSaveContext.cs b/src/Microsoft.ML.Core/Data/ModelSaveContext.cs similarity index 100% rename from src/Microsoft.ML.Data/Model/ModelSaveContext.cs rename to src/Microsoft.ML.Core/Data/ModelSaveContext.cs diff --git a/src/Microsoft.ML.Data/Model/ModelSaving.cs b/src/Microsoft.ML.Core/Data/ModelSaving.cs similarity index 100% rename from src/Microsoft.ML.Data/Model/ModelSaving.cs rename to src/Microsoft.ML.Core/Data/ModelSaving.cs diff --git a/src/Microsoft.ML.Data/Model/Repository.cs b/src/Microsoft.ML.Core/Data/Repository.cs similarity index 97% rename from src/Microsoft.ML.Data/Model/Repository.cs rename to src/Microsoft.ML.Core/Data/Repository.cs index dd3ebfe43a..42cf955b84 100644 --- a/src/Microsoft.ML.Data/Model/Repository.cs +++ b/src/Microsoft.ML.Core/Data/Repository.cs @@ -10,13 +10,11 @@ namespace Microsoft.ML.Model { - /// - /// Signature for a repository based model loader. This is the dual of ICanSaveModel. - /// - public delegate void SignatureLoadModel(ModelLoadContext ctx); - /// /// For saving a model into a repository. + /// Classes implementing should do an explicit implementation of . + /// Classes inheriting from a base class should overwrite the function invoked by + /// in that base class, if there is one. /// public interface ICanSaveModel { @@ -293,6 +291,8 @@ protected Entry AddEntry(string pathEnt, Stream stream) public sealed class RepositoryWriter : Repository { + private const string DirTrainingInfo = "TrainingInfo"; + private ZipArchive _archive; private Queue> _closed; @@ -301,7 +301,7 @@ public static RepositoryWriter CreateNew(Stream stream, IExceptionContext ectx = Contracts.CheckValueOrNull(ectx); ectx.CheckValue(stream, nameof(stream)); var rep = new RepositoryWriter(stream, ectx, useFileSystem); - using (var ent = rep.CreateEntry(ModelFileUtils.DirTrainingInfo, "Version.txt")) + using (var ent = rep.CreateEntry(DirTrainingInfo, "Version.txt")) using (var writer = Utils.OpenWriter(ent.Stream)) writer.WriteLine(typeof(RepositoryWriter).Assembly.GetName().Version); return rep; diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs index 25d00518f3..a42a84d559 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs @@ -942,7 +942,7 @@ private static Stream OpenStream(string filename) return OpenStream(files); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs index ff5348c58c..ccd25a0039 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs @@ -502,7 +502,7 @@ private static IDataLoader LoadTransforms(ModelLoadContext ctx, IDataLoader srcL }); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index 795fa2c66e..2a8a25df2b 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -834,7 +834,7 @@ public Bindings(ModelLoadContext ctx, TextLoader parent) OutputSchema = ComputeOutputSchema(); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); @@ -1283,7 +1283,7 @@ internal static IDataLoader Create(IHostEnvironment env, Arguments args, IMultiS internal static IDataView ReadFile(IHostEnvironment env, Arguments args, IMultiStreamSource fileSource) => new TextLoader(env, args, fileSource).Read(fileSource); - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -1420,7 +1420,7 @@ public RowCursor[] GetRowCursorSet(IEnumerable columnsNeeded, int return Cursor.CreateSet(_reader, _files, active, n); } - public void Save(ModelSaveContext ctx) => _reader.Save(ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => ((ICanSaveModel)_reader).Save(ctx); } } } \ No newline at end of file diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs index bf2af44b4c..24a29b1cc9 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs @@ -18,7 +18,7 @@ namespace Microsoft.ML.Data { // REVIEW: this class is public, as long as the Wrappers.cs in tests still rely on it. // It needs to become internal. - public sealed class TransformWrapper : ITransformer, ICanSaveModel + public sealed class TransformWrapper : ITransformer { public const string LoaderSignature = "TransformWrapper"; private const string TransformDirTemplate = "Step_{0:000}"; @@ -46,7 +46,7 @@ public Schema GetOutputSchema(Schema inputSchema) return output.Schema; } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { if (!_allowSave) throw _host.Except("Saving is not permitted."); diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs index 6b5ee01d8d..6e52f83e75 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs @@ -51,7 +51,7 @@ internal interface ITransformerChainAccessor /// A chain of transformers (possibly empty) that end with a . /// For an empty chain, is always . /// - public sealed class TransformerChain : ITransformer, ICanSaveModel, IEnumerable, ITransformerChainAccessor + public sealed class TransformerChain : ITransformer, IEnumerable, ITransformerChainAccessor where TLastTransformer : class, ITransformer { private readonly ITransformer[] _transformers; @@ -165,7 +165,7 @@ public TransformerChain Append(TNewLast transformer, Transfo return new TransformerChain(_transformers.AppendElement(transformer), _scopes.AppendElement(scope)); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); @@ -181,7 +181,7 @@ public void Save(ModelSaveContext ctx) } /// - /// The loading constructor of transformer chain. Reverse of . + /// The loading constructor of transformer chain. Reverse of . /// internal TransformerChain(IHostEnvironment env, ModelLoadContext ctx) { diff --git a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs index 625bcf71e7..f3b7629bc4 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs @@ -513,7 +513,7 @@ public static TransposeLoader Create(IHostEnvironment env, ModelLoadContext ctx, }); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs b/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs index 9f7290e0d1..c34881b2d6 100644 --- a/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs +++ b/src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs @@ -140,7 +140,7 @@ public Impl(IHostEnvironment env, string name, IDataView input, OneToOneColumn c Metadata.Seal(); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.Assert(false, "Shouldn't serialize this!"); throw Host.ExceptNotSupp("Shouldn't serialize this"); diff --git a/src/Microsoft.ML.Data/DataView/LambdaFilter.cs b/src/Microsoft.ML.Data/DataView/LambdaFilter.cs index 9e726ad9d4..f1cf418431 100644 --- a/src/Microsoft.ML.Data/DataView/LambdaFilter.cs +++ b/src/Microsoft.ML.Data/DataView/LambdaFilter.cs @@ -96,7 +96,7 @@ public Impl(IHostEnvironment env, string name, IDataView input, _conv = conv; } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.Assert(false, "Shouldn't serialize this!"); throw Host.ExceptNotSupp("Shouldn't serialize this"); diff --git a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs index e3a2377009..23b0211404 100644 --- a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs +++ b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs @@ -131,7 +131,7 @@ public static RowToRowMapperTransform Create(IHostEnvironment env, ModelLoadCont return h.Apply("Loading Model", ch => new RowToRowMapperTransform(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs b/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs index 3b88307eb6..affe071745 100644 --- a/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs +++ b/src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs @@ -223,7 +223,7 @@ public static ChooseColumnsByIndexTransform Create(IHostEnvironment env, ModelLo return h.Apply("Loading Model", ch => new ChooseColumnsByIndexTransform(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs index c4e5ad76cd..7bade429f1 100644 --- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs @@ -948,7 +948,7 @@ public static BinaryPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadC return new BinaryPerInstanceEvaluator(env, ctx, schema); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -960,7 +960,7 @@ public override void Save(ModelSaveContext ctx) // float: _threshold // byte: _useRaw - base.Save(ctx); + base.SaveModel(ctx); ctx.SaveStringOrNull(_probCol); Contracts.Assert(FloatUtils.IsFinite(_threshold)); ctx.Writer.Write(_threshold); diff --git a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs index 067dccfb50..f1c4b19b0a 100644 --- a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs @@ -628,13 +628,13 @@ public static ClusteringPerInstanceEvaluator Create(IHostEnvironment env, ModelL return new ClusteringPerInstanceEvaluator(env, ctx, schema); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { // *** Binary format ** // base // int: number of clusters - base.Save(ctx); + base.SaveModel(ctx); Host.Assert(_numClusters > 0); ctx.Writer.Write(_numClusters); } diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs index 88fa1ee285..f15c514e0e 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs @@ -510,7 +510,12 @@ protected PerInstanceEvaluatorBase(IHostEnvironment env, ModelLoadContext ctx, throw Host.ExceptSchemaMismatch(nameof(schema), "score", ScoreCol); } - public virtual void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + /// + /// Derived class, for example A, should overwrite so that (()A).Save(ctx) can correctly dump A. + /// + private protected virtual void SaveModel(ModelSaveContext ctx) { // *** Binary format ** // int: Id of the score column name diff --git a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs index 7e78be4a75..bb0a27d87d 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs @@ -631,7 +631,7 @@ public static MultiClassPerInstanceEvaluator Create(IHostEnvironment env, ModelL return new MultiClassPerInstanceEvaluator(env, ctx, schema); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -642,7 +642,7 @@ public override void Save(ModelSaveContext ctx) // int: number of classes // int[]: Ids of the class names - base.Save(ctx); + base.SaveModel(ctx); Host.Assert(_numClasses > 0); ctx.Writer.Write(_numClasses); for (int i = 0; i < _numClasses; i++) diff --git a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs index 1e232aff0d..6757da940c 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs @@ -426,7 +426,7 @@ public static MultiOutputRegressionPerInstanceEvaluator Create(IHostEnvironment return new MultiOutputRegressionPerInstanceEvaluator(env, ctx, schema); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -434,7 +434,7 @@ public override void Save(ModelSaveContext ctx) // *** Binary format ** // base - base.Save(ctx); + base.SaveModel(ctx); } private protected override Func GetDependenciesCore(Func activeOutput) diff --git a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs index 1332c58948..1f16938682 100644 --- a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs @@ -324,7 +324,7 @@ public static QuantileRegressionPerInstanceEvaluator Create(IHostEnvironment env return new QuantileRegressionPerInstanceEvaluator(env, ctx, schema); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -335,7 +335,7 @@ public override void Save(ModelSaveContext ctx) // int: _scoreSize // int[]: Ids of the quantile names - base.Save(ctx); + base.SaveModel(ctx); Host.Assert(_scoreSize > 0); ctx.Writer.Write(_scoreSize); var quantiles = _quantiles.GetValues(); diff --git a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs index ed215e8387..b420b1ad27 100644 --- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs @@ -597,11 +597,11 @@ public static RankerPerInstanceTransform Create(IHostEnvironment env, ModelLoadC return h.Apply("Loading Model", ch => new RankerPerInstanceTransform(h, ctx, input)); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); - _transform.Save(ctx); + ((ICanSaveModel)_transform).Save(ctx); } public long? GetRowCount() @@ -715,7 +715,7 @@ public Transform(IHostEnvironment env, ModelLoadContext ctx, IDataView input) _bindings = new Bindings(Host, input.Schema, false, LabelCol, ScoreCol, GroupCol, _truncationLevel); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.AssertValue(ctx); @@ -725,7 +725,7 @@ public override void Save(ModelSaveContext ctx) // int: _labelGains.Length // double[]: _labelGains - base.Save(ctx); + base.SaveModel(ctx); Host.Assert(0 < _truncationLevel && _truncationLevel < 100); ctx.Writer.Write(_truncationLevel); ctx.Writer.WriteDoubleArray(_labelGains); diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs index 69df01da67..2015f30eea 100644 --- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs @@ -234,7 +234,7 @@ public static RegressionPerInstanceEvaluator Create(IHostEnvironment env, ModelL return new RegressionPerInstanceEvaluator(env, ctx, schema); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -242,7 +242,7 @@ public override void Save(ModelSaveContext ctx) // *** Binary format ** // base - base.Save(ctx); + base.SaveModel(ctx); } private protected override Func GetDependenciesCore(Func activeOutput) diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index ef942c5cdf..2ee036d760 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -407,7 +407,7 @@ private static ValueMapperCalibratedModelParameters Crea return new ValueMapperCalibratedModelParameters(env, ctx); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -462,7 +462,7 @@ private static FeatureWeightsCalibratedModelParameters C return new FeatureWeightsCalibratedModelParameters(env, ctx); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -528,7 +528,7 @@ private static ParameterMixingCalibratedModelParameters return new ParameterMixingCalibratedModelParameters(env, ctx); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -698,7 +698,7 @@ private static SchemaBindableCalibratedModelParameters C return new SchemaBindableCalibratedModelParameters(env, ctx); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); @@ -1488,7 +1488,7 @@ private static PlattCalibrator Create(IHostEnvironment env, ModelLoadContext ctx return new PlattCalibrator(env, ctx); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs b/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs index ca384c0e4e..f53bb6993e 100644 --- a/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs +++ b/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs @@ -200,7 +200,7 @@ internal CalibratorTransformer(IHostEnvironment env, ModelLoadContext ctx, strin bool ITransformer.IsRowToRowMapper => true; - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); @@ -245,7 +245,7 @@ internal Mapper(CalibratorTransformer parent, TCalibrator calibrato private protected override Func GetDependenciesCore(Func activeOutput) => col => col == _scoreColIndex; - public override void Save(ModelSaveContext ctx) => _parent.Save(ctx); + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); protected override Schema.DetachedColumn[] GetOutputColumnsCore() { diff --git a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs index afb1f9bb51..5f837507db 100644 --- a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs +++ b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculation.cs @@ -163,7 +163,7 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema) return new RowMapper(env, this, schema); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs index 60f5864f9f..612231ebf8 100644 --- a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs @@ -114,7 +114,7 @@ public static Bindings Create(ModelLoadContext ctx, return Create(env, bindable, input, roles, suffix, user: false); } - public override void Save(ModelSaveContext ctx) + internal override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); @@ -205,7 +205,7 @@ private protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); - _bindings.Save(ctx); + _bindings.SaveModel(ctx); } void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx) diff --git a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs index 2f57bab274..25366f02f8 100644 --- a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs @@ -165,7 +165,7 @@ private static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadConte return h.Apply("Loading Model", ch => new LabelNameBindableMapper(h, ctx)); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs index a72b09906c..bf2ec96e10 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs @@ -160,7 +160,7 @@ public static BindingsImpl Create(ModelLoadContext ctx, Schema input, return new BindingsImpl(input, rowMapper, suffix, scoreKind, false, scoreColIndex, predColType); } - public override void Save(ModelSaveContext ctx) + internal override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); @@ -335,7 +335,7 @@ private protected PredictedLabelScorerBase(IHost host, ModelLoadContext ctx, IDa private protected override void SaveCore(ModelSaveContext ctx) { Host.AssertValue(ctx); - Bindings.Save(ctx); + Bindings.SaveModel(ctx); } void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx) diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index cbbb2bbe19..64f71aba83 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -134,7 +134,11 @@ public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) return (IRowToRowMapper)Scorer.ApplyToData(Host, new EmptyDataView(Host, inputSchema)); } - protected void SaveModel(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected abstract void SaveModel(ModelSaveContext ctx); + + protected void SaveModelCore(ModelSaveContext ctx) { // *** Binary format *** // @@ -157,7 +161,7 @@ protected void SaveModel(ModelSaveContext ctx) /// Those are all the transformers that work with one feature column. /// /// The model used to transform the data. - public abstract class SingleFeaturePredictionTransformerBase : PredictionTransformerBase, ISingleFeaturePredictionTransformer, ICanSaveModel + public abstract class SingleFeaturePredictionTransformerBase : PredictionTransformerBase, ISingleFeaturePredictionTransformer where TModel : class { /// @@ -226,7 +230,7 @@ public sealed override Schema GetOutputSchema(Schema inputSchema) return Transform(new EmptyDataView(Host, inputSchema)).Schema; } - public void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -235,7 +239,7 @@ public void Save(ModelSaveContext ctx) protected virtual void SaveCore(ModelSaveContext ctx) { - SaveModel(ctx); + SaveModelCore(ctx); ctx.SaveStringOrNull(FeatureColumn); } diff --git a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs index fb68553672..082b083a40 100644 --- a/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs @@ -49,7 +49,7 @@ private protected RowToRowScorerBase(IHost host, ModelLoadContext ctx, IDataView ctx.LoadModel(host, out Bindable, "SchemaBindableMapper"); } - public sealed override void Save(ModelSaveContext ctx) + private protected sealed override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); @@ -417,7 +417,7 @@ protected void SaveBase(ModelSaveContext ctx) } } - public abstract void Save(ModelSaveContext ctx); + internal abstract void SaveModel(ModelSaveContext ctx); protected override ColumnType GetColumnTypeCore(int iinfo) { diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs index 5d5e9e3178..5b62f45a7e 100644 --- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs +++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs @@ -76,7 +76,9 @@ protected SchemaBindablePredictorWrapperBase(IHostEnvironment env, ModelLoadCont ScoreType = GetScoreType(Predictor, out ValueMapper); } - public virtual void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected virtual void SaveModel(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -283,11 +285,11 @@ public static SchemaBindablePredictorWrapper Create(IHostEnvironment env, ModelL return new SchemaBindablePredictorWrapper(env, ctx); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.SetVersionInfo(GetVersionInfo()); - base.Save(ctx); + base.SaveModel(ctx); } private protected override void SaveAsPfaCore(BoundPfaContext ctx, RoleMappedSchema schema, string[] outputNames) @@ -390,11 +392,11 @@ public static SchemaBindableBinaryPredictorWrapper Create(IHostEnvironment env, return new SchemaBindableBinaryPredictorWrapper(env, ctx); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.SetVersionInfo(GetVersionInfo()); - base.Save(ctx); + base.SaveModel(ctx); } private protected override void SaveAsPfaCore(BoundPfaContext ctx, RoleMappedSchema schema, string[] outputNames) @@ -631,7 +633,7 @@ private SchemaBindableQuantileRegressionPredictor(IHostEnvironment env, ModelLoa Contracts.CheckDecode(Utils.Size(_quantiles) > 0); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.SetVersionInfo(GetVersionInfo()); @@ -641,7 +643,7 @@ public override void Save(ModelSaveContext ctx) // int: the number of quantiles // double[]: the quantiles - base.Save(ctx); + base.SaveModel(ctx); ctx.Writer.WriteDoubleArray(_quantiles); } diff --git a/src/Microsoft.ML.Data/Transforms/BootstrapSamplingTransformer.cs b/src/Microsoft.ML.Data/Transforms/BootstrapSamplingTransformer.cs index cb61c00db1..ee3289396f 100644 --- a/src/Microsoft.ML.Data/Transforms/BootstrapSamplingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/BootstrapSamplingTransformer.cs @@ -129,7 +129,7 @@ private BootstrapSamplingTransformer(IHost host, ModelLoadContext ctx, IDataView Host.CheckDecode(_poolSize >= 0); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs index f9b5d6bf8e..8206e66723 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs @@ -263,7 +263,7 @@ private static VersionInfo GetVersionInfo() private const int VersionAddedAliases = 0x00010002; private const int VersionTransformer = 0x00010003; - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -864,7 +864,7 @@ private protected override Func GetDependenciesCore(Func a protected override Schema.DetachedColumn[] GetOutputColumnsCore() => _columns.Select(x => x.MakeSchemaColumn()).ToArray(); - public override void Save(ModelSaveContext ctx) => _parent.Save(ctx); + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { diff --git a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs index 7dfdf17d23..07e5784bad 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs @@ -166,7 +166,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { ctx.SetVersionInfo(GetVersionInfo()); SaveColumns(ctx); diff --git a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs index c3e6b50018..00bd42fee0 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs @@ -124,7 +124,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) /// /// The allows the user to specify columns to drop or keep from a given input. /// - public sealed class ColumnSelectingTransformer : ITransformer, ICanSaveModel + public sealed class ColumnSelectingTransformer : ITransformer { internal const string Summary = "Selects which columns from the dataset to keep."; internal const string UserName = "Select Columns Transform"; @@ -417,7 +417,9 @@ private static IDataTransform Create(IHostEnvironment env, Options options, IDat return new SelectColumnsDataTransform(env, transform, new Mapper(transform, input.Schema), input); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + internal void SaveModel(ModelSaveContext ctx) { ctx.SetVersionInfo(GetVersionInfo()); @@ -678,7 +680,7 @@ public RowCursor[] GetRowCursorSet(IEnumerable columnsNeeded, int return cursors; } - public void Save(ModelSaveContext ctx) => _transform.Save(ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => _transform.SaveModel(ctx); public Func GetDependencies(Func activeOutput) { diff --git a/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs b/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs index 461a3f61e9..f6f46d62ac 100644 --- a/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs @@ -172,7 +172,7 @@ private FeatureContributionCalculatingTransformer(IHostEnvironment env, ModelLoa Normalize = ctx.Reader.ReadBoolByte(); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs index 54fb453894..c14cdd49d9 100644 --- a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs @@ -163,7 +163,7 @@ public static Bindings Create(ModelLoadContext ctx, Schema input) return new Bindings(useCounter, states, input, false, names); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); @@ -309,7 +309,7 @@ public static GenerateNumberTransform Create(IHostEnvironment env, ModelLoadCont return h.Apply("Loading Model", ch => new GenerateNumberTransform(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs index f98f4baa47..d60794d2d8 100644 --- a/src/Microsoft.ML.Data/Transforms/Hashing.cs +++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs @@ -278,7 +278,7 @@ private HashingTransformer(IHost host, ModelLoadContext ctx) TextModelHelper.LoadAll(Host, ctx, columnsLength, out _keyValues, out _kvTypes); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs index ffb6c61da3..eaaf566e95 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs @@ -144,7 +144,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs index d2c7a5e260..cc4ed51905 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs @@ -143,7 +143,7 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(KeyToVectorMappingTransformer).Assembly.FullName); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs index 408e62e330..cf0b9790ea 100644 --- a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs @@ -120,7 +120,7 @@ public static LabelConvertTransform Create(IHostEnvironment env, ModelLoadContex }); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs index 21b0b13bce..e848312ee9 100644 --- a/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs @@ -98,7 +98,7 @@ public static LabelIndicatorTransform Create(IHostEnvironment env, ch => new LabelIndicatorTransform(h, args, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/NAFilter.cs b/src/Microsoft.ML.Data/Transforms/NAFilter.cs index ffc04cbc80..956f201cd6 100644 --- a/src/Microsoft.ML.Data/Transforms/NAFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/NAFilter.cs @@ -169,7 +169,7 @@ public static NAFilter Create(IHostEnvironment env, ModelLoadContext ctx, IDataV return h.Apply("Loading Model", ch => new NAFilter(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/NopTransform.cs b/src/Microsoft.ML.Data/Transforms/NopTransform.cs index 92f701d704..234c060be2 100644 --- a/src/Microsoft.ML.Data/Transforms/NopTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/NopTransform.cs @@ -89,7 +89,7 @@ private NopTransform(IHost host, ModelLoadContext ctx, IDataView input) // Nothing :) } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs index 4888aeb4c4..f15f3e77fc 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs @@ -369,7 +369,9 @@ private AffineColumnFunction(IHost host) Host = host; } - public abstract void Save(ModelSaveContext ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected abstract void SaveModel(ModelSaveContext ctx); public abstract JToken PfaInfo(BoundPfaContext ctx, JToken srcToken); public bool CanSaveOnnx(OnnxContext ctx) => true; @@ -485,7 +487,9 @@ private CdfColumnFunction(IHost host) Host = host; } - public abstract void Save(ModelSaveContext ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected abstract void SaveModel(ModelSaveContext ctx); public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) => null; @@ -614,7 +618,9 @@ protected BinColumnFunction(IHost host) Host = host; } - public abstract void Save(ModelSaveContext ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected abstract void SaveModel(ModelSaveContext ctx); public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) => null; diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs index 54b4c11ce4..e2050bdcdc 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs @@ -568,7 +568,7 @@ private void GetResult(ref TFloat input, ref TFloat value) value = (input - Offset) * Scale; } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { AffineNormSerializationUtils.SaveModel(ctx, 1, null, new[] { Scale }, new[] { Offset }, saveText: true); } @@ -628,7 +628,7 @@ public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorType typeSr return new ImplVec(host, scales, offsets, (offsets != null && nz.Count < cv / 2) ? nz.ToArray() : null); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { AffineNormSerializationUtils.SaveModel(ctx, Scale.Length, null, Scale, Offset, saveText: true); } @@ -892,7 +892,7 @@ private void GetResult(ref TFloat input, ref TFloat value) value = CdfUtils.Cdf(val, Mean, Stddev); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); @@ -947,7 +947,7 @@ public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorType typeSr return new ImplVec(host, mean, stddev, useLog); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); @@ -1071,7 +1071,7 @@ public ImplOne(IHost host, TFloat[] binUpperBounds, bool fixZero) return new ImplOne(host, binUpperBounds[0], fixZero); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); @@ -1157,7 +1157,7 @@ public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorType typeSr return new ImplVec(host, binUpperBounds, fixZero); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs index 482e5d8f71..7a4abcdfbb 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumnSng.cs @@ -570,7 +570,7 @@ private void GetResult(ref TFloat input, ref TFloat value) value = (input - Offset) * Scale; } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { AffineNormSerializationUtils.SaveModel(ctx, 1, null, new[] { Scale }, new[] { Offset }, saveText: true); } @@ -629,7 +629,7 @@ public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorType typeSr return new ImplVec(host, scales, offsets, (offsets != null && nz.Count < cv / 2) ? nz.ToArray() : null); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { AffineNormSerializationUtils.SaveModel(ctx, Scale.Length, null, Scale, Offset, saveText: true); } @@ -896,7 +896,7 @@ private void GetResult(ref TFloat input, ref TFloat value) value = CdfUtils.Cdf(val, Mean, Stddev); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); @@ -951,7 +951,7 @@ public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorType typeSr return new ImplVec(host, mean, stddev, useLog); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); @@ -1076,7 +1076,7 @@ public ImplOne(IHost host, TFloat[] binUpperBounds, bool fixZero) return new ImplOne(host, binUpperBounds[0], fixZero); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); @@ -1162,7 +1162,7 @@ public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorType typeSr return new ImplVec(host, binUpperBounds, fixZero); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs index bff9503614..78edabbe0c 100644 --- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs +++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs @@ -540,7 +540,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs index f1cf1f52b3..4a4a94f7de 100644 --- a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs +++ b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs @@ -110,7 +110,7 @@ private protected override Func GetDependenciesCore(Func a return col => active[col]; } - public override void Save(ModelSaveContext ctx) => _parent.Save(ctx); + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); } } } diff --git a/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs b/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs index c5e95a7683..f226d28a30 100644 --- a/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs +++ b/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs @@ -133,7 +133,9 @@ protected PerGroupTransformBase(IHostEnvironment env, ModelLoadContext ctx, IDat GroupCol = ctx.LoadNonEmptyString(); } - public virtual void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected virtual void SaveModel(ModelSaveContext ctx) { Host.AssertValue(ctx); diff --git a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs index 838f1698b2..e6f356ffe4 100644 --- a/src/Microsoft.ML.Data/Transforms/RangeFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/RangeFilter.cs @@ -177,7 +177,7 @@ public static RangeFilter Create(IHostEnvironment env, ModelLoadContext ctx, IDa return h.Apply("Loading Model", ch => new RangeFilter(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs index 0ed952c379..0f6b6a5dc3 100644 --- a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs @@ -159,7 +159,7 @@ public static RowShufflingTransformer Create(IHostEnvironment env, ModelLoadCont return h.Apply("Loading Model", ch => new RowShufflingTransformer(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs index 104cdad96d..13622d6063 100644 --- a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs +++ b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs @@ -13,7 +13,7 @@ namespace Microsoft.ML.Data /// /// Base class for transformer which produce new columns, but doesn't affect existing ones. /// - public abstract class RowToRowTransformerBase : ITransformer, ICanSaveModel + public abstract class RowToRowTransformerBase : ITransformer { protected readonly IHost Host; @@ -23,7 +23,9 @@ protected RowToRowTransformerBase(IHost host) Host = host; } - public abstract void Save(ModelSaveContext ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected abstract void SaveModel(ModelSaveContext ctx); public bool IsRowToRowMapper => true; @@ -109,7 +111,9 @@ Func IRowMapper.GetDependencies(Func activeOutput) [BestFriend] private protected abstract Func GetDependenciesCore(Func activeOutput); - public abstract void Save(ModelSaveContext ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected abstract void SaveModel(ModelSaveContext ctx); public ITransformer GetTransformer() => _parent; } diff --git a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs index 6c8a316a7c..b83e701597 100644 --- a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs @@ -168,7 +168,7 @@ public static SkipTakeFilter Create(IHostEnvironment env, ModelLoadContext ctx, } ///Saves class data to context - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs index 9b57d6d031..037521e7e1 100644 --- a/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs @@ -324,7 +324,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/TransformBase.cs b/src/Microsoft.ML.Data/Transforms/TransformBase.cs index ec8dec1af9..0ab2b61c27 100644 --- a/src/Microsoft.ML.Data/Transforms/TransformBase.cs +++ b/src/Microsoft.ML.Data/Transforms/TransformBase.cs @@ -43,7 +43,9 @@ protected TransformBase(IHost host, IDataView input) Source = input; } - public abstract void Save(ModelSaveContext ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected abstract void SaveModel(ModelSaveContext ctx); public abstract long? GetRowCount(); @@ -388,7 +390,7 @@ public static Bindings Create(OneToOneTransformBase parent, ModelLoadContext ctx return new Bindings(parent, infos, inputSchema, false, names); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); diff --git a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs index 5bdbdb0a29..7dcd1ab641 100644 --- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs +++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs @@ -206,7 +206,7 @@ internal TypeConvertingTransformer(IHostEnvironment env, params TypeConvertingEs _columns = columns.ToArray(); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs index b6469b55cc..c74650a7a0 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs @@ -740,7 +740,7 @@ protected static PrimitiveType GetPrimitiveType(Type rawType, out bool isVectorT return ColumnTypeExtensions.PrimitiveTypeFromKind(kind); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs index 011d491608..dacd543f29 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs @@ -641,7 +641,7 @@ private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] info return termMap; } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs index 05b85b37ec..4582d4f4a9 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Average.cs @@ -12,7 +12,7 @@ namespace Microsoft.ML.Ensemble.OutputCombiners { - public sealed class Average : BaseAverager, ICanSaveModel, IRegressionOutputCombiner + public sealed class Average : BaseAverager, IRegressionOutputCombiner { public const string UserName = "Average"; public const string LoadName = "Average"; diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseAverager.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseAverager.cs index 1ffbe13ec1..0b454be238 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseAverager.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseAverager.cs @@ -7,7 +7,7 @@ namespace Microsoft.ML.Ensemble.OutputCombiners { - public abstract class BaseAverager : IBinaryOutputCombiner + public abstract class BaseAverager : IBinaryOutputCombiner, ICanSaveModel { protected readonly IHost Host; public BaseAverager(IHostEnvironment env, string name) @@ -30,7 +30,7 @@ protected BaseAverager(IHostEnvironment env, string name, ModelLoadContext ctx) Host.CheckDecode(cbFloat == sizeof(Single)); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs index 5e8296862b..737cbfd649 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseMultiCombiner.cs @@ -11,7 +11,7 @@ namespace Microsoft.ML.Ensemble.OutputCombiners { - public abstract class BaseMultiCombiner : IMultiClassOutputCombiner + public abstract class BaseMultiCombiner : IMultiClassOutputCombiner, ICanSaveModel { protected readonly IHost Host; @@ -49,7 +49,7 @@ internal BaseMultiCombiner(IHostEnvironment env, string name, ModelLoadContext c Normalize = ctx.Reader.ReadBoolByte(); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs index 81d0c545f9..c081feee14 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs @@ -15,7 +15,7 @@ namespace Microsoft.ML.Ensemble.OutputCombiners { - internal abstract class BaseStacking : IStackingTrainer + internal abstract class BaseStacking : IStackingTrainer, ICanSaveModel { public abstract class ArgumentsBase { @@ -67,7 +67,7 @@ private protected BaseStacking(IHostEnvironment env, string name, ModelLoadConte CheckMeta(); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Host.Check(Meta != null, "Can't save an untrained Stacking combiner"); Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs index 3b94196eb9..5173216711 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Median.cs @@ -59,7 +59,7 @@ public static Median Create(IHostEnvironment env, ModelLoadContext ctx) return new Median(env, ctx); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs index 1ba5cdf028..be3bd7dfa2 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiAverage.cs @@ -16,7 +16,7 @@ namespace Microsoft.ML.Ensemble.OutputCombiners { - public sealed class MultiAverage : BaseMultiAverager, ICanSaveModel + public sealed class MultiAverage : BaseMultiAverager { public const string LoadName = "MultiAverage"; public const string LoaderSignature = "MultiAverageCombiner"; diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs index 477a658355..94b0f38ae8 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiMedian.cs @@ -19,7 +19,7 @@ namespace Microsoft.ML.Ensemble.OutputCombiners /// /// Generic interface for combining outputs of multiple models /// - public sealed class MultiMedian : BaseMultiCombiner, ICanSaveModel + public sealed class MultiMedian : BaseMultiCombiner { public const string LoadName = "MultiMedian"; public const string LoaderSignature = "MultiMedianCombiner"; diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs index 3ec8ea85c8..f2bb13d029 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs @@ -20,7 +20,7 @@ namespace Microsoft.ML.Ensemble.OutputCombiners { using TVectorPredictor = IPredictorProducing>; - internal sealed class MultiStacking : BaseStacking>, ICanSaveModel, IMultiClassOutputCombiner + internal sealed class MultiStacking : BaseStacking>, IMultiClassOutputCombiner { public const string LoadName = "MultiStacking"; public const string LoaderSignature = "MultiStackingCombiner"; diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs index a8ba932926..a222606ae1 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs @@ -17,7 +17,7 @@ namespace Microsoft.ML.Ensemble.OutputCombiners { // REVIEW: Why is MultiVoting based on BaseMultiCombiner? Normalizing the model outputs // is senseless, so the base adds no real functionality. - public sealed class MultiVoting : BaseMultiCombiner, ICanSaveModel + public sealed class MultiVoting : BaseMultiCombiner { public const string LoadName = "MultiVoting"; public const string LoaderSignature = "MultiVotingCombiner"; diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs index deef23abcd..161487650a 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/MultiWeightedAverage.cs @@ -22,7 +22,7 @@ namespace Microsoft.ML.Ensemble.OutputCombiners /// /// Generic interface for combining outputs of multiple models /// - public sealed class MultiWeightedAverage : BaseMultiAverager, IWeightedAverager, ICanSaveModel + public sealed class MultiWeightedAverage : BaseMultiAverager, IWeightedAverager { public const string UserName = "Multi Weighted Average"; public const string LoadName = "MultiWeightedAverage"; diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs index ab57cfe9aa..239e386ebf 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs @@ -19,7 +19,7 @@ namespace Microsoft.ML.Ensemble.OutputCombiners { using TScalarPredictor = IPredictorProducing; - internal sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner, ICanSaveModel + internal sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner { public const string LoadName = "RegressionStacking"; public const string LoaderSignature = "RegressionStacking"; diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs index 891963dbea..9999544c6c 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs @@ -16,7 +16,7 @@ namespace Microsoft.ML.Ensemble.OutputCombiners { using TScalarPredictor = IPredictorProducing; - internal sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner, ICanSaveModel + internal sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner { public const string UserName = "Stacking"; public const string LoadName = "Stacking"; diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs index b18de67b18..3700782957 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/Voting.cs @@ -57,7 +57,7 @@ public static Voting Create(IHostEnvironment env, ModelLoadContext ctx) return new Voting(env, ctx); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs b/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs index ec1d2dcd11..e6164b1914 100644 --- a/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs +++ b/src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs @@ -19,7 +19,7 @@ namespace Microsoft.ML.Ensemble.OutputCombiners { - public sealed class WeightedAverage : BaseAverager, IWeightedAverager, ICanSaveModel + public sealed class WeightedAverage : BaseAverager, IWeightedAverager { public const string UserName = "Weighted Average"; public const string LoadName = "WeightedAverage"; diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs index 2e5da36c3a..6f22df98c5 100644 --- a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs @@ -482,7 +482,7 @@ protected SchemaBindablePipelineEnsembleBase(IHostEnvironment env, ModelLoadCont _inputCols[i] = ctx.LoadNonEmptyString(); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Host.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalQuantileRegressionTree.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalQuantileRegressionTree.cs index f93111201a..453a223449 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalQuantileRegressionTree.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalQuantileRegressionTree.cs @@ -50,7 +50,7 @@ public InternalQuantileRegressionTree(byte[] buffer, ref int position) _instanceWeights = buffer.ToDoubleArray(ref position); } - public override void Save(ModelSaveContext ctx) + internal override void Save(ModelSaveContext ctx) { // *** Binary format *** // double[]: Labels Distribution. diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs index 4727919022..8feb116f1d 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalRegressionTree.cs @@ -409,7 +409,7 @@ protected void Save(ModelSaveContext ctx, TreeType code) writer.WriteDoubleArray(_previousLeafValue); } - public virtual void Save(ModelSaveContext ctx) + internal virtual void Save(ModelSaveContext ctx) { Save(ctx, TreeType.Regression); } diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalTreeEnsemble.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalTreeEnsemble.cs index 421f788ef6..3a7a49573f 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsemble/InternalTreeEnsemble.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsemble/InternalTreeEnsemble.cs @@ -56,7 +56,7 @@ public InternalTreeEnsemble(ModelLoadContext ctx, bool usingDefaultValues, bool _firstInputInitializationContent = ctx.LoadStringOrNull(); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { // *** Binary format *** // int: Number of trees diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index e1aa26f56f..18a0cf1aec 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -393,7 +393,7 @@ public TreeEnsembleFeaturizerBindableMapper(IHostEnvironment env, ModelLoadConte _totalLeafCount = CountLeaves(_ensemble); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.HalLearners/VectorWhitening.cs b/src/Microsoft.ML.HalLearners/VectorWhitening.cs index f865ba4fcd..fc298c1ed1 100644 --- a/src/Microsoft.ML.HalLearners/VectorWhitening.cs +++ b/src/Microsoft.ML.HalLearners/VectorWhitening.cs @@ -464,7 +464,7 @@ private static void TrainModels(IHostEnvironment env, IChannel ch, float[][] col } } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs index cf5388d8da..fc9ed6c699 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscale.cs @@ -141,7 +141,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs index 8829d1dd2f..dab2dc17c2 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageLoader.cs @@ -138,7 +138,7 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, TextType.Instance.ToString(), inputSchema[srcCol].Type.ToString()); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs index a34fff871c..8359c4d9a6 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractor.cs @@ -252,7 +252,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs index 665cfd516d..a23c0d0a93 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageResizer.cs @@ -235,7 +235,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs b/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs index cc0a5da52e..da26e1f5dc 100644 --- a/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs @@ -195,7 +195,7 @@ public ColInfoEx(ModelLoadContext ctx) Interleave = ctx.Reader.ReadBoolByte(); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); @@ -306,7 +306,7 @@ public static VectorToImageTransform Create(IHostEnvironment env, ModelLoadConte }); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs index 30f20846c4..22bb76c019 100644 --- a/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs +++ b/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs @@ -261,7 +261,7 @@ internal OnnxTransformer(IHostEnvironment env, string[] outputColumnNames, strin { } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.AssertValue(ctx); @@ -368,7 +368,7 @@ private protected override Func GetDependenciesCore(Func a return col => Enumerable.Range(0, _parent.Outputs.Length).Any(i => activeOutput(i)) && _inputColIndices.Any(i => i == col); } - public override void Save(ModelSaveContext ctx) => _parent.Save(ctx); + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); private interface INamedOnnxValueGetter { diff --git a/src/Microsoft.ML.PCA/PcaTransformer.cs b/src/Microsoft.ML.PCA/PcaTransformer.cs index b6b134a1e0..dd17255d0a 100644 --- a/src/Microsoft.ML.PCA/PcaTransformer.cs +++ b/src/Microsoft.ML.PCA/PcaTransformer.cs @@ -140,7 +140,7 @@ public TransformInfo(ModelLoadContext ctx) Contracts.CheckDecode(MeanProjected == null || (MeanProjected.Length == Rank && FloatUtils.IsFinite(MeanProjected))); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); @@ -279,7 +279,7 @@ private static PrincipalComponentAnalysisTransformer Create(IHostEnvironment env return new PrincipalComponentAnalysisTransformer(host, ctx); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Parquet/ParquetLoader.cs b/src/Microsoft.ML.Parquet/ParquetLoader.cs index 0f733273fb..9f5eaa97d3 100644 --- a/src/Microsoft.ML.Parquet/ParquetLoader.cs +++ b/src/Microsoft.ML.Parquet/ParquetLoader.cs @@ -406,7 +406,7 @@ public RowCursor[] GetRowCursorSet(IEnumerable columnsNee return new RowCursor[] { GetRowCursor(columnsNeeded, rand) }; } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs b/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs index 8f2c83a17f..98c83a59e2 100644 --- a/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs +++ b/src/Microsoft.ML.Parquet/PartitionedFileLoader.cs @@ -253,7 +253,7 @@ public static PartitionedFileLoader Create(IHostEnvironment env, ModelLoadContex ch => new PartitionedFileLoader(host, ctx, files)); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Parquet/PartitionedPathParser.cs b/src/Microsoft.ML.Parquet/PartitionedPathParser.cs index 06ae2b9700..70a01be64b 100644 --- a/src/Microsoft.ML.Parquet/PartitionedPathParser.cs +++ b/src/Microsoft.ML.Parquet/PartitionedPathParser.cs @@ -148,7 +148,7 @@ public static SimplePartitionedPathParser Create(IHostEnvironment env, ModelLoad ch => new SimplePartitionedPathParser(host, ctx)); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.SetVersionInfo(GetVersionInfo()); @@ -261,7 +261,7 @@ public static ParquetPartitionedPathParser Create(IHostEnvironment env, ModelLoa ch => new ParquetPartitionedPathParser(host, ctx)); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs b/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs index 901a535307..a0007d9108 100644 --- a/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs +++ b/src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs @@ -155,7 +155,7 @@ private static MatrixFactorizationModelParameters Create(IHostEnvironment env, M /// /// Save model to the given context /// - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); @@ -391,7 +391,7 @@ public Row GetRow(Row input, Func active) /// /// Trains a . It factorizes the training matrix into the product of two low-rank matrices. /// - public sealed class MatrixFactorizationPredictionTransformer : PredictionTransformerBase, ICanSaveModel + public sealed class MatrixFactorizationPredictionTransformer : PredictionTransformerBase { internal const string LoaderSignature = "MaFactPredXf"; internal string MatrixColumnIndexColumnName { get; } @@ -488,7 +488,7 @@ public override Schema GetOutputSchema(Schema inputSchema) return Transform(new EmptyDataView(Host, inputSchema)).Schema; } - public void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs index 4bf005b25a..9aac84dae1 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineModelParameters.cs @@ -283,7 +283,7 @@ public float[] GetLatentWeights() } } - public sealed class FieldAwareFactorizationMachinePredictionTransformer : PredictionTransformerBase, ICanSaveModel + public sealed class FieldAwareFactorizationMachinePredictionTransformer : PredictionTransformerBase { public const string LoaderSignature = "FAFMPredXfer"; @@ -387,7 +387,7 @@ public override Schema GetOutputSchema(Schema inputSchema) /// Saves the transformer to file. /// /// The that facilitates saving to the . - public void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs index d2efdc8d57..b5cb2813af 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs @@ -166,7 +166,7 @@ internal static LinearModelStatistics Create(IHostEnvironment env, ModelLoadCont return new LinearModelStatistics(env, ctx); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Contracts.AssertValue(_env); _env.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index bef2bf1b1f..f87c2da37c 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -709,7 +709,7 @@ internal static (TFDataType[] tfOutputTypes, ColumnType[] outputTypes) GetOutput private protected override IRowMapper MakeRowMapper(Schema inputSchema) => new Mapper(this, inputSchema); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.AssertValue(ctx); ctx.CheckAtModel(); @@ -877,7 +877,7 @@ public Mapper(TensorFlowTransformer parent, Schema inputSchema) : } } - public override void Save(ModelSaveContext ctx) => _parent.Save(ctx); + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); private class OutputCache { diff --git a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs index 4a3593ca14..5a5f40a0b8 100644 --- a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs +++ b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs @@ -465,7 +465,7 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo _xSmooth = new CpuAlignedVector(_windowSize, CpuMathUtils.GetVectorAlignment()); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs b/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs index 15e7a84a8f..a7a183cd63 100644 --- a/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs +++ b/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs @@ -77,7 +77,7 @@ public ExponentialAverageTransform(IHostEnvironment env, ModelLoadContext ctx, I Host.CheckDecode(WindowSize == 1); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(WindowSize >= 1); @@ -89,7 +89,7 @@ public override void Save(ModelSaveContext ctx) // // Single _decay - base.Save(ctx); + base.SaveModel(ctx); ctx.Writer.Write(_decay); } diff --git a/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs b/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs index eeff5db3d2..f552ed2860 100644 --- a/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs +++ b/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs @@ -48,11 +48,11 @@ public override Schema GetOutputSchema(Schema inputSchema) return Transform(new EmptyDataView(Host, inputSchema)).Schema; } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { ctx.CheckAtModel(); Host.Assert(InitialWindowSize == 0); - base.Save(ctx); + base.SaveModel(ctx); // *** Binary format *** // diff --git a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs index cb0874e2ce..873b93a643 100644 --- a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs +++ b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs @@ -175,7 +175,7 @@ private IidChangePointDetector(IHostEnvironment env, IidChangePointDetector tran { } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -187,7 +187,7 @@ public override void Save(ModelSaveContext ctx) // *** Binary format *** // - base.Save(ctx); + base.SaveModel(ctx); } // Factory method for SignatureLoadRowMapper. diff --git a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs index 20cecfc285..a0930d110c 100644 --- a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs +++ b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs @@ -155,7 +155,7 @@ private IidSpikeDetector(IHostEnvironment env, IidSpikeDetector transform) { } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -166,7 +166,7 @@ public override void Save(ModelSaveContext ctx) // *** Binary format *** // - base.Save(ctx); + base.SaveModel(ctx); } // Factory method for SignatureLoadRowMapper. diff --git a/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs b/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs index f207c7b75a..8cea2e751f 100644 --- a/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs +++ b/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs @@ -94,7 +94,7 @@ public MovingAverageTransform(IHostEnvironment env, ModelLoadContext ctx, IDataV Host.CheckDecode(_weights == null || Utils.Size(_weights) == WindowSize + 1 - _lag); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(WindowSize >= 1); @@ -107,7 +107,7 @@ public override void Save(ModelSaveContext ctx) // int: _lag // Single[]: _weights - base.Save(ctx); + base.SaveModel(ctx); ctx.Writer.Write(_lag); Host.Assert(_weights == null || Utils.Size(_weights) == WindowSize + 1 - _lag); ctx.Writer.WriteSingleArray(_weights); diff --git a/src/Microsoft.ML.TimeSeries/PValueTransform.cs b/src/Microsoft.ML.TimeSeries/PValueTransform.cs index d8f3adcb29..b22d7fcc0c 100644 --- a/src/Microsoft.ML.TimeSeries/PValueTransform.cs +++ b/src/Microsoft.ML.TimeSeries/PValueTransform.cs @@ -91,7 +91,7 @@ public PValueTransform(IHostEnvironment env, ModelLoadContext ctx, IDataView inp Host.CheckDecode(WindowSize >= 1); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(WindowSize >= 1); @@ -103,7 +103,7 @@ public override void Save(ModelSaveContext ctx) // int: _percentile // byte: _isPositiveSide - base.Save(ctx); + base.SaveModel(ctx); ctx.Writer.Write(_seed); ctx.Writer.WriteBoolByte(_isPositiveSide); } diff --git a/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs b/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs index 868e446899..18a12ed3be 100644 --- a/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs +++ b/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs @@ -86,7 +86,7 @@ public PercentileThresholdTransform(IHostEnvironment env, ModelLoadContext ctx, Host.CheckDecode(MinPercentile <= _percentile && _percentile <= MaxPercentile); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(MinPercentile <= _percentile && _percentile <= MaxPercentile); @@ -98,7 +98,7 @@ public override void Save(ModelSaveContext ctx) // // Double: _percentile - base.Save(ctx); + base.SaveModel(ctx); ctx.Writer.Write(_percentile); } diff --git a/src/Microsoft.ML.TimeSeries/SequenceModelerBase.cs b/src/Microsoft.ML.TimeSeries/SequenceModelerBase.cs index d66b58c9f8..ca113b9007 100644 --- a/src/Microsoft.ML.TimeSeries/SequenceModelerBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequenceModelerBase.cs @@ -75,6 +75,8 @@ private protected SequenceModelerBase() /// /// Implementation of . /// - public abstract void Save(ModelSaveContext ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected abstract void SaveModel(ModelSaveContext ctx); } } diff --git a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs index 91a554b96d..e0d722a7a6 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs @@ -230,7 +230,7 @@ private protected SequentialAnomalyDetectionTransformBase(IHostEnvironment env, _outputLength = GetOutputLength(ThresholdScore, Host); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -252,7 +252,7 @@ public override void Save(ModelSaveContext ctx) // Double: _powerMartingaleEpsilon // Double: _alertThreshold - base.Save(ctx); + base.SaveModel(ctx); ctx.Writer.Write((byte)Martingale); ctx.Writer.Write((byte)ThresholdScore); ctx.Writer.Write((byte)Side); @@ -623,7 +623,7 @@ public Func GetDependencies(Func activeOutput) return col => false; } - public void Save(ModelSaveContext ctx) => _parent.Save(ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => _parent.SaveModel(ctx); public Delegate[] CreateGetters(Row input, Func activeOutput, out Action disposer) { diff --git a/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs b/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs index 79e2c81846..a13baf46ae 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs @@ -302,7 +302,7 @@ private protected SequentialTransformBase(IHostEnvironment env, ModelLoadContext _transform = CreateLambdaTransform(Host, input, OutputColumnName, InputColumnName, InitFunction, WindowSize > 0, ct); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(InitialWindowSize >= 0); diff --git a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs index dd25afe637..31a0b57ea4 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs @@ -25,7 +25,7 @@ namespace Microsoft.ML.TimeSeriesProcessing /// The input type of the sequential processing. /// The dst type of the sequential processing. /// The state type of the sequential processing. Must be a class inherited from StateBase - public abstract class SequentialTransformerBase : IStatefulTransformer, ICanSaveModel + public abstract class SequentialTransformerBase : IStatefulTransformer where TState : SequentialTransformerBase.StateBase, new() { /// @@ -320,7 +320,9 @@ private protected SequentialTransformerBase(IHost host, ModelLoadContext ctx) OutputColumnType = bs.LoadTypeDescriptionOrNull(ctx.Reader.BaseStream); } - public virtual void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected virtual void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(InitialWindowSize >= 0); @@ -450,9 +452,9 @@ protected override RowCursor GetRowCursorCore(IEnumerable columns public override RowCursor[] GetRowCursorSet(IEnumerable columnsNeeded, int n, Random rand = null) => new RowCursor[] { GetRowCursorCore(columnsNeeded, rand) }; - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { - _parent.Save(ctx); + (_parent as ICanSaveModel).Save(ctx); } IDataTransform ITransformTemplate.ApplyToData(IHostEnvironment env, IDataView newSource) @@ -637,7 +639,7 @@ public static TimeSeriesRowToRowMapperTransform Create(IHostEnvironment env, Mod return h.Apply("Loading Model", ch => new TimeSeriesRowToRowMapperTransform(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.TimeSeries/SlidingWindowTransform.cs b/src/Microsoft.ML.TimeSeries/SlidingWindowTransform.cs index 96cddd0207..ca94f2c788 100644 --- a/src/Microsoft.ML.TimeSeries/SlidingWindowTransform.cs +++ b/src/Microsoft.ML.TimeSeries/SlidingWindowTransform.cs @@ -49,7 +49,7 @@ public SlidingWindowTransform(IHostEnvironment env, ModelLoadContext ctx, IDataV // } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -57,7 +57,7 @@ public override void Save(ModelSaveContext ctx) // *** Binary format *** // - base.Save(ctx); + base.SaveModel(ctx); } } } diff --git a/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs b/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs index 1555c98afa..05b01a0909 100644 --- a/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs +++ b/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs @@ -110,7 +110,7 @@ private TInput GetNaValue() return nanValue; } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Assert(WindowSize >= 1); @@ -123,7 +123,7 @@ public override void Save(ModelSaveContext ctx) // Int32 lag // byte begin - base.Save(ctx); + base.SaveModel(ctx); ctx.Writer.Write(_lag); ctx.Writer.Write((byte)_begin); } diff --git a/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs b/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs index eac9a8ebeb..857b21bcdf 100644 --- a/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs +++ b/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs @@ -176,7 +176,7 @@ public override Schema GetOutputSchema(Schema inputSchema) return Transform(new EmptyDataView(Host, inputSchema)).Schema; } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -196,7 +196,7 @@ public override void Save(ModelSaveContext ctx) // State: StateRef // AdaptiveSingularSpectrumSequenceModeler: _model - base.Save(ctx); + base.SaveModel(ctx); ctx.Writer.Write(SeasonalWindowSize); ctx.Writer.Write(DiscountFactor); ctx.Writer.Write((byte)ErrorFunction); diff --git a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs index 3f5e3cb009..a90029b4a6 100644 --- a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs +++ b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs @@ -183,7 +183,7 @@ internal SsaChangePointDetector(IHostEnvironment env, ModelLoadContext ctx) Host.CheckDecode(IsAdaptive == false); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -197,7 +197,7 @@ public override void Save(ModelSaveContext ctx) // *** Binary format *** // - base.Save(ctx); + base.SaveModel(ctx); } // Factory method for SignatureLoadRowMapper. diff --git a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs index 02b81458b3..5c1af6ccd8 100644 --- a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs +++ b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs @@ -165,7 +165,7 @@ internal SsaSpikeDetector(IHostEnvironment env, ModelLoadContext ctx) Host.CheckDecode(IsAdaptive == false); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -178,7 +178,7 @@ public override void Save(ModelSaveContext ctx) // *** Binary format *** // - base.Save(ctx); + base.SaveModel(ctx); } // Factory method for SignatureLoadRowMapper. diff --git a/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs b/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs index e639ad7a6d..609d9c842f 100644 --- a/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs +++ b/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs @@ -19,7 +19,7 @@ namespace Microsoft.ML.Transforms /// /// The type that describes what 'source' columns are consumed from the input . /// The type that describes what new columns are added by this transform. - public sealed class CustomMappingTransformer : ITransformer, ICanSaveModel + public sealed class CustomMappingTransformer : ITransformer where TSrc : class, new() where TDst : class, new() { @@ -60,7 +60,8 @@ public CustomMappingTransformer(IHostEnvironment env, Action mapActi AddedSchema = outSchema; } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + internal void SaveModel(ModelSaveContext ctx) { if (_contractName == null) throw _host.Except("Empty contract name for a transform: the transform cannot be saved"); @@ -174,8 +175,8 @@ Schema.DetachedColumn[] IRowMapper.GetOutputColumns() return Enumerable.Range(0, dstRow.Schema.Count).Select(x => new Schema.DetachedColumn(dstRow.Schema[x])).ToArray(); } - public void Save(ModelSaveContext ctx) - => _parent.Save(ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) + => _parent.SaveModel(ctx); public ITransformer GetTransformer() { diff --git a/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs b/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs index af04cd10d5..40915c484c 100644 --- a/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs +++ b/src/Microsoft.ML.Transforms/FourierDistributionSampler.cs @@ -104,7 +104,7 @@ private GaussianFourierSampler(IHostEnvironment env, ModelLoadContext ctx) _host.CheckDecode(FloatUtils.IsFinite(_gamma)); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { ctx.SetVersionInfo(GetVersionInfo()); @@ -185,7 +185,7 @@ private LaplacianFourierSampler(IHostEnvironment env, ModelLoadContext ctx) _host.CheckDecode(FloatUtils.IsFinite(_a)); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { ctx.SetVersionInfo(GetVersionInfo()); diff --git a/src/Microsoft.ML.Transforms/GcnTransform.cs b/src/Microsoft.ML.Transforms/GcnTransform.cs index f48a4329b7..5c54d5eccc 100644 --- a/src/Microsoft.ML.Transforms/GcnTransform.cs +++ b/src/Microsoft.ML.Transforms/GcnTransform.cs @@ -306,7 +306,7 @@ private LpNormalizingTransformer(IHost host, ModelLoadContext ctx) _columns[i] = new ColumnInfoLoaded(ctx, ColumnPairs[i].outputColumnName, ColumnPairs[i].inputColumnName, ctx.Header.ModelVerWritten >= VerVectorNormalizerSupported); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.Transforms/GroupTransform.cs b/src/Microsoft.ML.Transforms/GroupTransform.cs index e36a2c4c4a..4c4ac807da 100644 --- a/src/Microsoft.ML.Transforms/GroupTransform.cs +++ b/src/Microsoft.ML.Transforms/GroupTransform.cs @@ -137,7 +137,7 @@ private GroupTransform(IHost host, ModelLoadContext ctx, IDataView input) _groupBinding = new GroupBinding(input.Schema, host, ctx); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -305,7 +305,7 @@ private Schema BuildOutputSchema(Schema sourceSchema) return schemaBuilder.GetSchema(); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { _ectx.AssertValue(ctx); diff --git a/src/Microsoft.ML.Transforms/HashJoiningTransform.cs b/src/Microsoft.ML.Transforms/HashJoiningTransform.cs index b0a4edbea8..61f2e86adb 100644 --- a/src/Microsoft.ML.Transforms/HashJoiningTransform.cs +++ b/src/Microsoft.ML.Transforms/HashJoiningTransform.cs @@ -285,7 +285,7 @@ public static HashJoiningTransform Create(IHostEnvironment env, ModelLoadContext return h.Apply("Loading Model", ch => new HashJoiningTransform(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs b/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs index 4364ba6097..036287d3ed 100644 --- a/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs +++ b/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs @@ -83,7 +83,7 @@ internal KeyToBinaryVectorMappingTransformer(IHostEnvironment env, params (strin { } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.Transforms/LambdaTransform.cs b/src/Microsoft.ML.Transforms/LambdaTransform.cs index 2907a41941..c823ae3999 100644 --- a/src/Microsoft.ML.Transforms/LambdaTransform.cs +++ b/src/Microsoft.ML.Transforms/LambdaTransform.cs @@ -204,7 +204,7 @@ public static ITransformTemplate CreateFilter(IHostEnvironment env /// * a custom save action that serializes the transform 'state' to the binary writer. /// * a custom load action that de-serializes the transform from the binary reader. This must be a public static method of a public class. /// - internal abstract class LambdaTransformBase + internal abstract class LambdaTransformBase : ICanSaveModel { private readonly Action _saveAction; private readonly byte[] _loadFuncBytes; @@ -248,7 +248,7 @@ protected LambdaTransformBase(IHostEnvironment env, string name, LambdaTransform AssertConsistentSerializable(); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); Host.Check(CanSave(), "Cannot save this transform as it was not specified as being savable"); diff --git a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs index a5813aa392..84593a7805 100644 --- a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs @@ -135,7 +135,7 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Sch /// /// Saves the transform. /// - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs index 4cfbddf91d..74e54e896e 100644 --- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs +++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs @@ -116,7 +116,7 @@ public static MissingValueIndicatorTransform Create(IHostEnvironment env, ModelL }); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs index eb91e3eeed..31d775dc68 100644 --- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs @@ -131,7 +131,7 @@ internal static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Sc /// /// Saves the transform. /// - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs index b99676cbc4..21c0054fff 100644 --- a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs +++ b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs @@ -492,7 +492,7 @@ private void WriteTypeAndValue(Stream stream, BinarySaver saver, ColumnType t throw Host.Except("We do not know how to serialize terms of type '{0}'", type); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.Transforms/OneHotEncoding.cs b/src/Microsoft.ML.Transforms/OneHotEncoding.cs index 31505972ed..bc9bec1555 100644 --- a/src/Microsoft.ML.Transforms/OneHotEncoding.cs +++ b/src/Microsoft.ML.Transforms/OneHotEncoding.cs @@ -26,7 +26,7 @@ namespace Microsoft.ML.Transforms.Categorical { /// - public sealed class OneHotEncodingTransformer : ITransformer, ICanSaveModel + public sealed class OneHotEncodingTransformer : ITransformer { public enum OutputKind : byte { @@ -165,7 +165,7 @@ internal OneHotEncodingTransformer(ValueToKeyMappingEstimator term, IEstimator _transformer.Transform(input); - public void Save(ModelSaveContext ctx) => _transformer.Save(ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => (_transformer as ICanSaveModel).Save(ctx); public bool IsRowToRowMapper => _transformer.IsRowToRowMapper; diff --git a/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs b/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs index 6c61b29762..df693a9c1b 100644 --- a/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs +++ b/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs @@ -24,7 +24,7 @@ namespace Microsoft.ML.Transforms.Categorical /// /// Produces a column of indicator vectors. The mapping between a value and a corresponding index is done through hashing. /// - public sealed class OneHotHashEncodingTransformer : ITransformer, ICanSaveModel + public sealed class OneHotHashEncodingTransformer : ITransformer { internal sealed class Column : OneToOneColumn { @@ -189,7 +189,7 @@ internal OneHotHashEncodingTransformer(HashingEstimator hash, IEstimator public IDataView Transform(IDataView input) => _transformer.Transform(input); - public void Save(ModelSaveContext ctx) => _transformer.Save(ctx); + void ICanSaveModel.Save(ModelSaveContext ctx) => (_transformer as ICanSaveModel).Save(ctx); /// /// Whether a call to should succeed, on an appropriate schema. diff --git a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs index 3927ae6daf..c0da05f872 100644 --- a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs +++ b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs @@ -280,7 +280,7 @@ public static OptionalColumnTransform Create(IHostEnvironment env, ModelLoadCont return h.Apply("Loading Model", ch => new OptionalColumnTransform(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/ProduceIdTransform.cs b/src/Microsoft.ML.Transforms/ProduceIdTransform.cs index 63abdc723f..91ac258ace 100644 --- a/src/Microsoft.ML.Transforms/ProduceIdTransform.cs +++ b/src/Microsoft.ML.Transforms/ProduceIdTransform.cs @@ -27,7 +27,7 @@ namespace Microsoft.ML.Transforms /// some other file, then apply this transform to that dataview, it may of course have a different /// result. This is distinct from most transforms that produce results based on data alone. /// - public sealed class ProduceIdTransform : RowToRowTransformBase + internal sealed class ProduceIdTransform : RowToRowTransformBase { public sealed class Arguments { @@ -60,7 +60,7 @@ public static Bindings Create(ModelLoadContext ctx, Schema input) return new Bindings(input, true, name); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); @@ -127,7 +127,7 @@ public static ProduceIdTransform Create(IHostEnvironment env, ModelLoadContext c return h.Apply("Loading Model", ch => new ProduceIdTransform(h, ctx, input)); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs index 754ac9b573..171e8a6427 100644 --- a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs +++ b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs @@ -162,7 +162,7 @@ public TransformInfo(IHostEnvironment env, ModelLoadContext ctx, string director InitializeFourierCoefficients(roundedUpNumFeatures, roundedUpD); } - public void Save(ModelSaveContext ctx, string directoryName) + internal void Save(ModelSaveContext ctx, string directoryName) { Contracts.AssertValue(ctx); @@ -463,7 +463,7 @@ private static RandomFourierFeaturizingTransformer Create(IHostEnvironment env, return new RandomFourierFeaturizingTransformer(host, ctx); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index 7141d697ec..f759edd697 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -340,7 +340,7 @@ internal LdaSummary GetLdaSummary(VBuffer> mapping) } } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); long memBlockSize = 0; @@ -733,7 +733,7 @@ private static LatentDirichletAllocationTransformer Create(IHostEnvironment env, }); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs index 7d01048a44..84e95b54d4 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs @@ -256,7 +256,7 @@ internal NgramHashingTransformer(IHostEnvironment env, IDataView input, params N } } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -605,7 +605,7 @@ private protected override Func GetDependenciesCore(Func a return col => active[col]; } - public override void Save(ModelSaveContext ctx) => _parent.Save(ctx); + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); protected override Schema.DetachedColumn[] GetOutputColumnsCore() { diff --git a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs index b32a96aea9..acb31ed104 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs @@ -162,7 +162,7 @@ public TransformInfo(ModelLoadContext ctx, bool readWeighting) NonEmptyLevels = ctx.Reader.ReadBoolArray(NgramLength); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { Contracts.AssertValue(ctx); @@ -451,7 +451,7 @@ private static NgramExtractingTransformer Create(IHostEnvironment env, ModelLoad return new NgramExtractingTransformer(host, ctx); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs index f633c711b4..46a2b56a78 100644 --- a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs @@ -203,7 +203,7 @@ internal StopWordsRemovingTransformer(IHostEnvironment env, params StopWordsRemo _columns = columns; } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -477,7 +477,7 @@ private protected override Func GetDependenciesCore(Func a return col => active[col]; } - public override void Save(ModelSaveContext ctx) => _parent.Save(ctx); + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); } } @@ -850,7 +850,7 @@ internal CustomStopWordsRemovingTransformer(IHostEnvironment env, string stopwor LoadStopWords(ch, stopwords.AsMemory(), dataFile, stopwordsColumn, loader, out _stopWordsMap); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs index 985c6eb6fa..2ea6f6c071 100644 --- a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs +++ b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs @@ -562,7 +562,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat return estimator.Fit(data).Transform(data) as IDataTransform; } - private sealed class Transformer : ITransformer, ICanSaveModel + private sealed class Transformer : ITransformer { private const string TransformDirTemplate = "Step_{0:000}"; @@ -607,7 +607,7 @@ public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) return new CompositeRowToRowMapper(inputSchema, revMaps.ToArray()); } - public void Save(ModelSaveContext ctx) + void ICanSaveModel.Save(ModelSaveContext ctx) { _host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs index 75bab0f8b4..e5764fafd8 100644 --- a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs +++ b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs @@ -118,7 +118,7 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, TextNormalizingEstimator.ExpectedColumnType, type.ToString()); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs index 8ef8109c52..bd23d4396e 100644 --- a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs +++ b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs @@ -136,7 +136,7 @@ private TokenizingByCharactersTransformer(IHost host, ModelLoadContext ctx) : private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) => Create(env, ctx).MakeDataTransform(input); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs index 8a6c8108b5..42fa72033e 100644 --- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs +++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs @@ -284,7 +284,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs index 1d46ed420e..bf05a60e84 100644 --- a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs +++ b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs @@ -150,7 +150,7 @@ private WordTokenizingTransformer(IHost host, ModelLoadContext ctx) : private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) => Create(env, ctx).MakeDataTransform(input); - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); diff --git a/src/Microsoft.ML.Transforms/UngroupTransform.cs b/src/Microsoft.ML.Transforms/UngroupTransform.cs index 2dd0112dab..e0a594d8d6 100644 --- a/src/Microsoft.ML.Transforms/UngroupTransform.cs +++ b/src/Microsoft.ML.Transforms/UngroupTransform.cs @@ -138,7 +138,7 @@ private UngroupTransform(IHost host, ModelLoadContext ctx, IDataView input) _ungroupBinding = UngroupBinding.Create(ctx, host, input.Schema); } - public override void Save(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); @@ -365,7 +365,7 @@ public static UngroupBinding Create(ModelLoadContext ctx, IExceptionContext ectx return new UngroupBinding(ectx, inputSchema, mode, pivotColumns); } - public void Save(ModelSaveContext ctx) + internal void Save(ModelSaveContext ctx) { _ectx.AssertValue(ctx); diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/DiagnosticVerifier.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/DiagnosticVerifier.cs index 7bf7e5693f..248a4645f5 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/DiagnosticVerifier.cs +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Helpers/DiagnosticVerifier.cs @@ -15,6 +15,7 @@ using Microsoft.Data.DataView; using Microsoft.ML.Data; using Microsoft.ML.StaticPipe; +using Microsoft.ML.Transforms.Conversions; using Xunit; namespace Microsoft.ML.CodeAnalyzer.Tests.Helpers @@ -267,7 +268,7 @@ private static string FormatDiagnostics(DiagnosticAnalyzer analyzer, params Diag private static readonly MetadataReference MSDataDataViewReference = RefFromType(); private static readonly MetadataReference MLNetCoreReference = RefFromType(); - private static readonly MetadataReference MLNetDataReference = RefFromType(); + private static readonly MetadataReference MLNetDataReference = RefFromType(); private static readonly MetadataReference MLNetStaticPipeReference = RefFromType(); protected static MetadataReference RefFromType()