From 6c2cea5ac8a7b6da44fcc98048c58d7adfb0e46f Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Mon, 4 Mar 2019 14:50:11 -0800 Subject: [PATCH 1/3] APIs corresponding to SignatureLoadModel should not create objects generic on IPredictorProducing<>. --- .../EntryPoints/PredictorModelImpl.cs | 2 +- .../EntryPoints/SummarizePredictor.cs | 2 +- .../Prediction/Calibrator.cs | 40 ++++++++++--------- .../Scorers/PredictionTransformer.cs | 37 +++++++++-------- .../EntryPoints/PipelineEnsemble.cs | 2 +- .../TreeEnsemble/TreeEnsembleCombiner.cs | 6 +-- .../TreeEnsembleFeaturizer.cs | 2 +- .../Standard/SdcaBinary.cs | 2 +- .../Standard/StochasticTrainerBase.cs | 2 +- 9 files changed, 48 insertions(+), 47 deletions(-) diff --git a/src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs b/src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs index 97559e6d6e..9422995411 100644 --- a/src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs +++ b/src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs @@ -114,7 +114,7 @@ internal override string[] GetLabelInfo(IHostEnvironment env, out DataViewType l var calibrated = predictor as IWeaklyTypedCalibratedModelParameters; while (calibrated != null) { - predictor = calibrated.WeeklyTypedSubModel; + predictor = calibrated.WeaklyTypedSubModel; calibrated = predictor as IWeaklyTypedCalibratedModelParameters; } var canGetTrainingLabelNames = predictor as ICanGetTrainingLabelNames; diff --git a/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs b/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs index e2d15bd809..aa0d2e0032 100644 --- a/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs +++ b/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs @@ -51,7 +51,7 @@ internal static IDataView GetSummaryAndStats(IHostEnvironment env, IPredictor pr var calibrated = predictor as IWeaklyTypedCalibratedModelParameters; while (calibrated != null) { - predictor = calibrated.WeeklyTypedSubModel; + predictor = calibrated.WeaklyTypedSubModel; calibrated = predictor as IWeaklyTypedCalibratedModelParameters; } diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 46e8063df0..e461d8085e 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -14,7 +14,6 @@ using Microsoft.ML.CommandLine; using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; -using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; using Microsoft.ML.Model.OnnxConverter; @@ -58,19 +57,19 @@ "Naive Calibration Executor", NaiveCalibrator.LoaderSignature)] -[assembly: LoadableClass(typeof(ValueMapperCalibratedModelParameters, ICalibrator>), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(ValueMapperCalibratedModelParameters), null, typeof(SignatureLoadModel), "Calibrated Predictor Executor", ValueMapperCalibratedModelParameters, ICalibrator>.LoaderSignature, "BulkCaliPredExec")] -[assembly: LoadableClass(typeof(FeatureWeightsCalibratedModelParameters, ICalibrator>), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(FeatureWeightsCalibratedModelParameters), null, typeof(SignatureLoadModel), "Feature Weights Calibrated Predictor Executor", FeatureWeightsCalibratedModelParameters, ICalibrator>.LoaderSignature)] -[assembly: LoadableClass(typeof(ParameterMixingCalibratedModelParameters, ICalibrator>), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(ParameterMixingCalibratedModelParameters), null, typeof(SignatureLoadModel), "Parameter Mixing Calibrated Predictor Executor", ParameterMixingCalibratedModelParameters, ICalibrator>.LoaderSignature)] -[assembly: LoadableClass(typeof(SchemaBindableCalibratedModelParameters, ICalibrator>), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(SchemaBindableCalibratedModelParameters), null, typeof(SignatureLoadModel), "Schema Bindable Calibrated Predictor", SchemaBindableCalibratedModelParameters, ICalibrator>.LoaderSignature)] [assembly: LoadableClass(typeof(void), typeof(Calibrate), null, typeof(SignatureEntryPointModule), "Calibrate")] @@ -147,8 +146,8 @@ internal interface ISelfCalibratingPredictor [BestFriend] internal interface IWeaklyTypedCalibratedModelParameters { - IPredictorProducing WeeklyTypedSubModel { get; } - ICalibrator WeeklyTypedCalibrator { get; } + IPredictorProducing WeaklyTypedSubModel { get; } + ICalibrator WeaklyTypedCalibrator { get; } } /// @@ -186,8 +185,8 @@ public abstract class CalibratedModelParametersBase : public TCalibrator Calibrator { get; } // Type-unsafed accessors of strongly-typed members. - IPredictorProducing IWeaklyTypedCalibratedModelParameters.WeeklyTypedSubModel => (IPredictorProducing)SubModel; - ICalibrator IWeaklyTypedCalibratedModelParameters.WeeklyTypedCalibrator => Calibrator; + IPredictorProducing IWeaklyTypedCalibratedModelParameters.WeaklyTypedSubModel => (IPredictorProducing)SubModel; + ICalibrator IWeaklyTypedCalibratedModelParameters.WeaklyTypedCalibrator => Calibrator; PredictionKind IPredictor.PredictionKind => ((IPredictorProducing)SubModel).PredictionKind; @@ -198,6 +197,7 @@ private protected CalibratedModelParametersBase(IHostEnvironment env, string nam Host = env.Register(name); Host.CheckValue(predictor, nameof(predictor)); Host.CheckValue(calibrator, nameof(calibrator)); + Host.Assert(predictor is IPredictorProducing); SubModel = predictor; Calibrator = calibrator; @@ -270,7 +270,7 @@ internal abstract class ValueMapperCalibratedModelParametersBase, IValueMapperDist, IFeatureContributionMapper, ICalculateFeatureContribution, IDistCanSavePfa, IDistCanSaveOnnx - where TSubModel : class, IPredictorProducing + where TSubModel : class where TCalibrator : class, ICalibrator { private readonly IValueMapper _mapper; @@ -380,7 +380,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string [BestFriend] internal sealed class ValueMapperCalibratedModelParameters : ValueMapperCalibratedModelParametersBase, ICanSaveModel - where TSubModel : class, IPredictorProducing + where TSubModel : class where TCalibrator : class, ICalibrator { internal ValueMapperCalibratedModelParameters(IHostEnvironment env, TSubModel predictor, TCalibrator calibrator) @@ -442,7 +442,7 @@ internal sealed class FeatureWeightsCalibratedModelParameters, IPredictorWithFeatureWeights, ICanSaveModel - where TSubModel : class, IPredictorWithFeatureWeights + where TSubModel : class where TCalibrator : class, ICalibrator { private readonly IPredictorWithFeatureWeights _featureWeights; @@ -451,7 +451,8 @@ internal FeatureWeightsCalibratedModelParameters(IHostEnvironment env, TSubModel TCalibrator calibrator) : base(env, RegistrationName, predictor, calibrator) { - _featureWeights = predictor; + Host.Assert(predictor is IPredictorWithFeatureWeights); + _featureWeights = predictor as IPredictorWithFeatureWeights; } internal const string LoaderSignature = "FeatWCaliPredExec"; @@ -506,7 +507,7 @@ internal sealed class ParameterMixingCalibratedModelParameters, IPredictorWithFeatureWeights, ICanSaveModel - where TSubModel : class, IPredictorWithFeatureWeights + where TSubModel : class where TCalibrator : class, ICalibrator { private readonly IPredictorWithFeatureWeights _featureWeights; @@ -516,7 +517,8 @@ internal ParameterMixingCalibratedModelParameters(IHostEnvironment env, TSubMode { Host.Check(predictor is IParameterMixer, "Predictor does not implement " + nameof(IParameterMixer)); Host.Check(calibrator is IParameterMixer, "Calibrator does not implement " + nameof(IParameterMixer)); - _featureWeights = predictor; + Host.Assert(predictor is IPredictorWithFeatureWeights); + _featureWeights = predictor as IPredictorWithFeatureWeights; } internal const string LoaderSignature = "PMixCaliPredExec"; @@ -538,7 +540,7 @@ private ParameterMixingCalibratedModelParameters(IHostEnvironment env, ModelLoad { Host.Check(SubModel is IParameterMixer, "Predictor does not implement " + nameof(IParameterMixer)); Host.Check(SubModel is IPredictorWithFeatureWeights, "Predictor does not implement " + nameof(IPredictorWithFeatureWeights)); - _featureWeights = SubModel; + _featureWeights = SubModel as IPredictorWithFeatureWeights; } private static ParameterMixingCalibratedModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) @@ -587,7 +589,7 @@ IParameterMixer IParameterMixer.CombineParameters(IList : CalibratedModelParametersBase, ISchemaBindableMapper, ICanSaveModel, IBindableCanSavePfa, IBindableCanSaveOnnx, IFeatureContributionMapper - where TSubModel : class, IPredictorProducing + where TSubModel : class where TCalibrator : class, ICalibrator { private sealed class Bound : ISchemaBoundRowMapper @@ -702,14 +704,14 @@ private static VersionInfo GetVersionInfo() internal SchemaBindableCalibratedModelParameters(IHostEnvironment env, TSubModel predictor, TCalibrator calibrator) : base(env, LoaderSignature, predictor, calibrator) { - _bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubModel); + _bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubModel as IPredictorProducing); _featureContribution = SubModel as IFeatureContributionMapper; } private SchemaBindableCalibratedModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature, GetPredictor(env, ctx), GetCalibrator(env, ctx)) { - _bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubModel); + _bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubModel as IPredictorProducing); _featureContribution = SubModel as IFeatureContributionMapper; } diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index 7b94c92795..7f5376feba 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -7,24 +7,23 @@ using Microsoft.ML; using Microsoft.ML.Data; using Microsoft.ML.Data.IO; -using Microsoft.ML.Model; -[assembly: LoadableClass(typeof(BinaryPredictionTransformer>), typeof(BinaryPredictionTransformer), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(BinaryPredictionTransformer), typeof(BinaryPredictionTransformer), null, typeof(SignatureLoadModel), "", BinaryPredictionTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(MulticlassPredictionTransformer>>), typeof(MulticlassPredictionTransformer), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(MulticlassPredictionTransformer), typeof(MulticlassPredictionTransformer), null, typeof(SignatureLoadModel), "", MulticlassPredictionTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(RegressionPredictionTransformer>), typeof(RegressionPredictionTransformer), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(RegressionPredictionTransformer), typeof(RegressionPredictionTransformer), null, typeof(SignatureLoadModel), "", RegressionPredictionTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(RankingPredictionTransformer>), typeof(RankingPredictionTransformer), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(RankingPredictionTransformer), typeof(RankingPredictionTransformer), null, typeof(SignatureLoadModel), "", RankingPredictionTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(AnomalyPredictionTransformer>), typeof(AnomalyPredictionTransformer), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(AnomalyPredictionTransformer), typeof(AnomalyPredictionTransformer), null, typeof(SignatureLoadModel), "", AnomalyPredictionTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(ClusteringPredictionTransformer>>), typeof(ClusteringPredictionTransformer), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(ClusteringPredictionTransformer), typeof(ClusteringPredictionTransformer), null, typeof(SignatureLoadModel), "", ClusteringPredictionTransformer.LoaderSignature)] namespace Microsoft.ML.Data @@ -595,47 +594,47 @@ internal static class BinaryPredictionTransformer { public const string LoaderSignature = "BinaryPredXfer"; - public static BinaryPredictionTransformer> Create(IHostEnvironment env, ModelLoadContext ctx) - => new BinaryPredictionTransformer>(env, ctx); + public static BinaryPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + => new BinaryPredictionTransformer(env, ctx); } internal static class MulticlassPredictionTransformer { public const string LoaderSignature = "MulticlassPredXfer"; - public static MulticlassPredictionTransformer>> Create(IHostEnvironment env, ModelLoadContext ctx) - => new MulticlassPredictionTransformer>>(env, ctx); + public static MulticlassPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + => new MulticlassPredictionTransformer(env, ctx); } internal static class RegressionPredictionTransformer { public const string LoaderSignature = "RegressionPredXfer"; - public static RegressionPredictionTransformer> Create(IHostEnvironment env, ModelLoadContext ctx) - => new RegressionPredictionTransformer>(env, ctx); + public static RegressionPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + => new RegressionPredictionTransformer(env, ctx); } internal static class RankingPredictionTransformer { public const string LoaderSignature = "RankingPredXfer"; - public static RankingPredictionTransformer> Create(IHostEnvironment env, ModelLoadContext ctx) - => new RankingPredictionTransformer>(env, ctx); + public static RankingPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + => new RankingPredictionTransformer(env, ctx); } internal static class AnomalyPredictionTransformer { public const string LoaderSignature = "AnomalyPredXfer"; - public static AnomalyPredictionTransformer> Create(IHostEnvironment env, ModelLoadContext ctx) - => new AnomalyPredictionTransformer>(env, ctx); + public static AnomalyPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + => new AnomalyPredictionTransformer(env, ctx); } internal static class ClusteringPredictionTransformer { public const string LoaderSignature = "ClusteringPredXfer"; - public static ClusteringPredictionTransformer>> Create(IHostEnvironment env, ModelLoadContext ctx) - => new ClusteringPredictionTransformer>>(env, ctx); + public static ClusteringPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + => new ClusteringPredictionTransformer(env, ctx); } } diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs index fe65254366..61fedbc848 100644 --- a/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs @@ -38,7 +38,7 @@ public static SummaryOutput Summarize(IHostEnvironment env, SummarizePredictor.I var calibrated = predictor as IWeaklyTypedCalibratedModelParameters; while (calibrated != null) { - predictor = calibrated.WeeklyTypedSubModel; + predictor = calibrated.WeaklyTypedSubModel; calibrated = predictor as IWeaklyTypedCalibratedModelParameters; } var ensemble = predictor as SchemaBindablePipelineEnsembleBase; diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs index 8643058684..fab70dbbb4 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs @@ -51,11 +51,11 @@ IPredictor IModelCombiner.CombineModels(IEnumerable models) var calibrated = predictor as IWeaklyTypedCalibratedModelParameters; double paramA = 1; if (calibrated != null) - _host.Check(calibrated.WeeklyTypedCalibrator is PlattCalibrator, + _host.Check(calibrated.WeaklyTypedCalibrator is PlattCalibrator, "Combining FastTree models can only be done when the models are calibrated with Platt calibrator"); - predictor = calibrated.WeeklyTypedSubModel; - paramA = -((PlattCalibrator)calibrated.WeeklyTypedCalibrator).Slope; + predictor = calibrated.WeaklyTypedSubModel; + paramA = -((PlattCalibrator)calibrated.WeaklyTypedCalibrator).Slope; var tree = predictor as TreeEnsembleModelParameters; diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index f58de700c9..8d62f49eea 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -590,7 +590,7 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData // Make sure that the given predictor has the correct number of input features. if (predictor is IWeaklyTypedCalibratedModelParameters calibrated) - predictor = calibrated.WeeklyTypedSubModel; + predictor = calibrated.WeaklyTypedSubModel; // Predictor should be a TreeEnsembleModelParameters, which implements IValueMapper, so this should // be non-null. var vm = predictor as IValueMapper; diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs index a7b3fbfccb..10bdaea4ea 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs @@ -72,7 +72,7 @@ private protected override TModel TrainModelCore(TrainContext context) { var preparedData = PrepareDataFromTrainingExamples(ch, context.TrainingSet, out int weightSetCount); var initPred = context.InitialPredictor; - var linInitPred = (initPred as IWeaklyTypedCalibratedModelParameters)?.WeeklyTypedSubModel as LinearModelParameters; + var linInitPred = (initPred as IWeaklyTypedCalibratedModelParameters)?.WeaklyTypedSubModel as LinearModelParameters; linInitPred = linInitPred ?? initPred as LinearModelParameters; Host.CheckParam(context.InitialPredictor == null || linInitPred != null, nameof(context), "Initial predictor was not a linear predictor."); diff --git a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs index acab62317c..4c858697a1 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs @@ -35,7 +35,7 @@ private protected override TModel TrainModelCore(TrainContext context) var preparedData = PrepareDataFromTrainingExamples(ch, context.TrainingSet, out int weightSetCount); var initPred = context.InitialPredictor; // Try extract linear model from calibrated predictor. - var linInitPred = (initPred as IWeaklyTypedCalibratedModelParameters)?.WeeklyTypedSubModel as LinearModelParameters; + var linInitPred = (initPred as IWeaklyTypedCalibratedModelParameters)?.WeaklyTypedSubModel as LinearModelParameters; // If the initial predictor is not calibrated, it should be a linear model. linInitPred = linInitPred ?? initPred as LinearModelParameters; Host.CheckParam(context.InitialPredictor == null || linInitPred != null, nameof(context), From 7b08c00084419193accb29789c8a3f5441707f8f Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Mon, 4 Mar 2019 14:50:38 -0800 Subject: [PATCH 2/3] Add unit test --- .../Api/Estimators/DeserializationTests.cs | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs new file mode 100644 index 0000000000..fc8a507bb5 --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs @@ -0,0 +1,57 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.IO; +using System.Linq; +using Microsoft.ML.Calibrators; +using Microsoft.ML.Data; +using Microsoft.ML.RunTests; +using Microsoft.ML.Trainers.FastTree; +using Xunit; + +namespace Microsoft.ML.Tests.Scenarios.Api +{ + public partial class ApiScenariosTests + { + private class InputData + { + [LoadColumn(0)] + public float Label { get; set; } + [LoadColumn(9, 14)] + [VectorType(6)] + public float[] Features { get; set; } + } + + [Fact] + public void LoadModelAndExtractPredictor() + { + var ml = new MLContext(seed: 1, conc: 1); + var file = new MultiFileSource(GetDataPath(TestDatasets.adult.trainFilename)); + var loader = ml.Data.CreateTextLoader(hasHeader: true, dataSample: file); + var data = loader.Load(file); + + // Pipeline. + var pipeline = ml.BinaryClassification.Trainers.GeneralizedAdditiveModels(); + + // Train. + var model = pipeline.Fit(data); + + // Save and reload. + string modelPath = GetOutputPath(FullTestName + "-model.zip"); + using (var fs = File.Create(modelPath)) + ml.Model.Save(model, fs); + + ITransformer loadedModel; + using (var fs = File.OpenRead(modelPath)) + loadedModel = ml.Model.Load(fs); + + var gam = (((loadedModel as TransformerChain).LastTransformer + as BinaryPredictionTransformer).Model + as CalibratedModelParametersBase).SubModel + as BinaryClassificationGamModelParameters; + Assert.NotNull(gam); + } + } +} From 2cf7ee4bbafd0faab4b1f44fe5fa1c1d15afe25f Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Tue, 5 Mar 2019 10:30:19 -0800 Subject: [PATCH 3/3] Make IDataLoader inherit from ICanSaveModel, and add methods to save/load it in ModelOperationsCatalog. --- src/Microsoft.ML.Core/Data/IEstimator.cs | 2 +- .../DataLoadSave/CompositeDataLoader.cs | 82 +++++++++++-------- .../DataLoadSave/Text/TextLoader.cs | 6 +- .../DataLoadSave/TransformerChain.cs | 4 +- .../Model/ModelOperationsCatalog.cs | 21 +++++ .../Api/Estimators/DeserializationTests.cs | 57 +++++++++++++ 6 files changed, 135 insertions(+), 37 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs index 8c4d9b85a0..1c69771e85 100644 --- a/src/Microsoft.ML.Core/Data/IEstimator.cs +++ b/src/Microsoft.ML.Core/Data/IEstimator.cs @@ -224,7 +224,7 @@ internal bool TryFindColumn(string name, out Column column) /// The 'data loader' takes a certain kind of input and turns it into an . /// /// The type of input the loader takes. - public interface IDataLoader + public interface IDataLoader : ICanSaveModel { /// /// Produce the data view from the specified input. diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs index 04d6faf750..4be940a396 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs @@ -4,7 +4,11 @@ using System.IO; using Microsoft.Data.DataView; -using Microsoft.ML.Model; +using Microsoft.ML; +using Microsoft.ML.Data; + +[assembly: LoadableClass(CompositeDataLoader.Summary, typeof(CompositeDataLoader), null, typeof(SignatureLoadModel), + "Composite Loader", CompositeDataLoader.LoaderSignature)] namespace Microsoft.ML.Data { @@ -15,6 +19,10 @@ namespace Microsoft.ML.Data public sealed class CompositeDataLoader : IDataLoader where TLastTransformer : class, ITransformer { + private const string LoaderDirectory = "Loader"; + private const string LegacyLoaderDirectory = "Reader"; + private const string TransformerDirectory = TransformerChain.LoaderSignature; + /// /// The underlying data loader. /// @@ -33,6 +41,24 @@ public CompositeDataLoader(IDataLoader loader, TransformerChain(); } + private CompositeDataLoader(IHost host, ModelLoadContext ctx) + { + if (!ctx.LoadModelOrNull, SignatureLoadModel>(host, out Loader, LegacyLoaderDirectory)) + ctx.LoadModel, SignatureLoadModel>(host, out Loader, LoaderDirectory); + ctx.LoadModel, SignatureLoadModel>(host, out Transformer, TransformerDirectory); + } + + private static CompositeDataLoader Create(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + IHost h = env.Register(LoaderSignature); + + h.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + + return h.Apply("Loading Model", ch => new CompositeDataLoader(h, ctx)); + } + /// /// Produce the data view from the specified input. /// Note that 's are lazy, so no actual loading happens here, just schema validation. @@ -62,6 +88,16 @@ public CompositeDataLoader AppendTransformer(TNewLa return new CompositeDataLoader(Loader, Transformer.Append(transformer)); } + void ICanSaveModel.Save(ModelSaveContext ctx) + { + Contracts.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + + ctx.SaveModel(Loader, LoaderDirectory); + ctx.SaveModel(Transformer, TransformerDirectory); + } + /// /// Save the contents to a stream, as a "model file". /// @@ -76,47 +112,27 @@ public void SaveTo(IHostEnvironment env, Stream outputStream) using (var rep = RepositoryWriter.CreateNew(outputStream, ch)) { ch.Trace("Saving data loader"); - ModelSaveContext.SaveModel(rep, Loader, "Reader"); + ModelSaveContext.SaveModel(rep, Loader, LoaderDirectory); ch.Trace("Saving transformer chain"); - ModelSaveContext.SaveModel(rep, Transformer, TransformerChain.LoaderSignature); + ModelSaveContext.SaveModel(rep, Transformer, TransformerDirectory); rep.Commit(); } } } - } - /// - /// Utility class to facilitate loading from a stream. - /// - [BestFriend] - internal static class CompositeDataLoader - { - /// - /// Save the contents to a stream, as a "model file". - /// - public static void SaveTo(this IDataLoader loader, IHostEnvironment env, Stream outputStream) - => new CompositeDataLoader(loader).SaveTo(env, outputStream); + internal const string Summary = "A loader that encapsulates a loader and a transformer chain."; - /// - /// Load the pipeline from stream. - /// - public static CompositeDataLoader LoadFrom(IHostEnvironment env, Stream stream) + internal const string LoaderSignature = "CompositeLoader"; + private static VersionInfo GetVersionInfo() { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(stream, nameof(stream)); - - env.Check(stream.CanRead && stream.CanSeek, "Need a readable and seekable stream to load"); - using (var rep = RepositoryReader.Open(stream, env)) - using (var ch = env.Start("Loading pipeline")) - { - ch.Trace("Loading data loader"); - ModelLoadContext.LoadModel, SignatureLoadModel>(env, out var loader, rep, "Reader"); - - ch.Trace("Loader transformer chain"); - ModelLoadContext.LoadModel, SignatureLoadModel>(env, out var transformerChain, rep, TransformerChain.LoaderSignature); - return new CompositeDataLoader(loader, transformerChain); - } + return new VersionInfo( + modelSignature: "CMPSTLDR", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(CompositeDataLoader<,>).Assembly.FullName); } } } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index be48c2659c..c9e5fa06d1 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -12,7 +12,6 @@ using Microsoft.ML.CommandLine; using Microsoft.ML.Data; using Microsoft.ML.Internal.Utilities; -using Microsoft.ML.Model; [assembly: LoadableClass(TextLoader.Summary, typeof(ILegacyDataLoader), typeof(TextLoader), typeof(TextLoader.Options), typeof(SignatureDataLoader), "Text Loader", "TextLoader", "Text", DocName = "loader/TextLoader.md")] @@ -20,12 +19,15 @@ [assembly: LoadableClass(TextLoader.Summary, typeof(ILegacyDataLoader), typeof(TextLoader), null, typeof(SignatureLoadDataLoader), "Text Loader", TextLoader.LoaderSignature)] +[assembly: LoadableClass(TextLoader.Summary, typeof(TextLoader), null, typeof(SignatureLoadModel), + "Text Loader", TextLoader.LoaderSignature)] + namespace Microsoft.ML.Data { /// /// Loads a text file into an IDataView. Supports basic mapping from input columns to columns. /// - public sealed partial class TextLoader : IDataLoader, ICanSaveModel + public sealed partial class TextLoader : IDataLoader { /// /// Describes how an input column should be mapped to an column. diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs index 31eba5c376..d4bd677845 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs @@ -255,7 +255,9 @@ public static TransformerChain LoadFrom(IHostEnvironment env, Stre { try { - ModelLoadContext.LoadModel, SignatureLoadModel>(env, out var transformerChain, rep, LoaderSignature); + ModelLoadContext.LoadModelOrNull, SignatureLoadModel>(env, out var transformerChain, rep, LoaderSignature); + if (transformerChain == null) + ModelLoadContext.LoadModel, SignatureLoadModel>(env, out transformerChain, rep, $@"Model\{LoaderSignature}"); return transformerChain; } catch (FormatException ex) diff --git a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs index c32bb74a4b..68300e41ae 100644 --- a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs @@ -45,6 +45,18 @@ protected SubCatalogBase(ModelOperationsCatalog owner) /// A writeable, seekable stream to save to. public void Save(ITransformer model, Stream stream) => model.SaveTo(Environment, stream); + public void Save(IDataLoader model, Stream stream) + { + using (var rep = RepositoryWriter.CreateNew(stream)) + { + ModelSaveContext.SaveModel(rep, model, "Model"); + rep.Commit(); + } + } + + public void Save(IDataLoader loader, ITransformer model, Stream stream) => + Save(new CompositeDataLoader(loader, new TransformerChain(model)), stream); + /// /// Load the model from the stream. /// @@ -52,6 +64,15 @@ protected SubCatalogBase(ModelOperationsCatalog owner) /// The loaded model. public ITransformer Load(Stream stream) => TransformerChain.LoadFrom(Environment, stream); + public IDataLoader LoadAsCompositeDataLoader(Stream stream) + { + using (var rep = RepositoryReader.Open(stream)) + { + ModelLoadContext.LoadModel, SignatureLoadModel>(Environment, out var model, rep, "Model"); + return model; + } + } + /// /// The catalog of model explainability operations. /// diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs index fc8a507bb5..07718e8e01 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs @@ -53,5 +53,62 @@ public void LoadModelAndExtractPredictor() as BinaryClassificationGamModelParameters; Assert.NotNull(gam); } + + [Fact] + public void SaveAndLoadModelWithLoader() + { + var ml = new MLContext(seed: 1, conc: 1); + var file = new MultiFileSource(GetDataPath(TestDatasets.adult.trainFilename)); + var loader = ml.Data.CreateTextLoader(hasHeader: true, dataSample: file); + var data = loader.Load(file); + + // Pipeline. + var pipeline = ml.BinaryClassification.Trainers.GeneralizedAdditiveModels(); + + // Train. + var model = pipeline.Fit(data); + + // Save and reload. + string modelPath = GetOutputPath(FullTestName + "-model.zip"); + using (var fs = File.Create(modelPath)) + ml.Model.Save(loader, model, fs); + + IDataLoader loadedModel; + ITransformer loadedModelWithoutLoader; + using (var fs = File.OpenRead(modelPath)) + { + loadedModel = ml.Model.LoadAsCompositeDataLoader(fs); + loadedModelWithoutLoader = ml.Model.Load(fs); + } + + // Without deserializing the loader from the model we lose the slot names. + data = ml.Data.LoadFromEnumerable(new[] { new InputData() }); + data = loadedModelWithoutLoader.Transform(data); + Assert.Null(data.Schema["Features"].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)); + + data = loadedModel.Load(file); + Assert.True(data.Schema["Features"].HasSlotNames(data.Schema["Features"].Type.GetValueCount())); + VBuffer> slotNames = default; + data.Schema["Features"].GetSlotNames(ref slotNames); + var ageIndex = FindIndex(slotNames.GetValues(), "age"); + var transformer = (loadedModel as CompositeDataLoader).Transformer.LastTransformer; + var gamModel = ((transformer as BinaryPredictionTransformer).Model + as CalibratedModelParametersBase).SubModel + as BinaryClassificationGamModelParameters; + var ageBinUpperBounds = gamModel.GetBinUpperBounds(ageIndex); + var ageBinEffects = gamModel.GetBinEffects(ageIndex); + } + + private int FindIndex(ReadOnlySpan> values, string slotName) + { + int index = 0; + foreach (var value in values) + { + if (value.Span.SequenceEqual(slotName.AsSpan())) + return index; + index++; + } + return -1; + } } }