From db9a74afa98170ad84789544837a41787fa5765f Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Tue, 5 Mar 2019 12:25:34 -0800 Subject: [PATCH 01/18] Add save/load APIs for IDataLoader --- src/Microsoft.ML.Core/Data/IEstimator.cs | 2 +- .../DataLoadSave/CompositeDataLoader.cs | 81 ++++++++----- .../DataLoadSave/Text/TextLoader.cs | 79 ++++++------ .../DataLoadSave/TransformerChain.cs | 6 +- .../EntryPoints/PredictorModelImpl.cs | 2 +- .../EntryPoints/SummarizePredictor.cs | 2 +- .../Model/ModelOperationsCatalog.cs | 21 ++++ .../Prediction/Calibrator.cs | 39 +++--- .../Scorers/PredictionTransformer.cs | 60 +++++---- .../EntryPoints/PipelineEnsemble.cs | 2 +- .../TreeEnsemble/TreeEnsembleCombiner.cs | 50 ++++---- .../TreeEnsembleFeaturizer.cs | 2 +- .../Standard/SdcaBinary.cs | 2 +- .../Standard/StochasticTrainerBase.cs | 2 +- .../Api/Estimators/DeserializationTests.cs | 114 ++++++++++++++++++ 15 files changed, 311 insertions(+), 153 deletions(-) create mode 100644 test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs index eb6e5e1b40..ab5ca4c880 100644 --- a/src/Microsoft.ML.Core/Data/IEstimator.cs +++ b/src/Microsoft.ML.Core/Data/IEstimator.cs @@ -224,7 +224,7 @@ internal bool TryFindColumn(string name, out Column column) /// The 'data loader' takes a certain kind of input and turns it into an . /// /// The type of input the loader takes. - public interface IDataLoader + public interface IDataLoader : ICanSaveModel { /// /// Produce the data view from the specified input. diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs index 9d7d607709..98fb04f309 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs @@ -3,8 +3,13 @@ // See the LICENSE file in the project root for more information. using System.IO; +using Microsoft.ML; +using Microsoft.ML.Data; using Microsoft.ML.Runtime; +[assembly: LoadableClass(CompositeDataLoader.Summary, typeof(CompositeDataLoader), null, typeof(SignatureLoadModel), + "Composite Loader", CompositeDataLoader.LoaderSignature)] + namespace Microsoft.ML.Data { /// @@ -14,6 +19,10 @@ namespace Microsoft.ML.Data public sealed class CompositeDataLoader : IDataLoader where TLastTransformer : class, ITransformer { + private const string LoaderDirectory = "Loader"; + private const string LegacyLoaderDirectory = "Reader"; + private const string TransformerDirectory = TransformerChain.LoaderSignature; + /// /// The underlying data loader. /// @@ -32,6 +41,24 @@ public CompositeDataLoader(IDataLoader loader, TransformerChain(); } + private CompositeDataLoader(IHost host, ModelLoadContext ctx) + { + if (!ctx.LoadModelOrNull, SignatureLoadModel>(host, out Loader, LegacyLoaderDirectory)) + ctx.LoadModel, SignatureLoadModel>(host, out Loader, LoaderDirectory); + ctx.LoadModel, SignatureLoadModel>(host, out Transformer, TransformerDirectory); + } + + private static CompositeDataLoader Create(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + IHost h = env.Register(LoaderSignature); + + h.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + + return h.Apply("Loading Model", ch => new CompositeDataLoader(h, ctx)); + } + /// /// Produce the data view from the specified input. /// Note that 's are lazy, so no actual loading happens here, just schema validation. @@ -61,6 +88,16 @@ public CompositeDataLoader AppendTransformer(TNewLa return new CompositeDataLoader(Loader, Transformer.Append(transformer)); } + void ICanSaveModel.Save(ModelSaveContext ctx) + { + Contracts.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + + ctx.SaveModel(Loader, LoaderDirectory); + ctx.SaveModel(Transformer, TransformerDirectory); + } + /// /// Save the contents to a stream, as a "model file". /// @@ -75,47 +112,27 @@ public void SaveTo(IHostEnvironment env, Stream outputStream) using (var rep = RepositoryWriter.CreateNew(outputStream, ch)) { ch.Trace("Saving data loader"); - ModelSaveContext.SaveModel(rep, Loader, "Reader"); + ModelSaveContext.SaveModel(rep, Loader, LoaderDirectory); ch.Trace("Saving transformer chain"); - ModelSaveContext.SaveModel(rep, Transformer, TransformerChain.LoaderSignature); + ModelSaveContext.SaveModel(rep, Transformer, TransformerDirectory); rep.Commit(); } } } - } - /// - /// Utility class to facilitate loading from a stream. - /// - [BestFriend] - internal static class CompositeDataLoader - { - /// - /// Save the contents to a stream, as a "model file". - /// - public static void SaveTo(this IDataLoader loader, IHostEnvironment env, Stream outputStream) - => new CompositeDataLoader(loader).SaveTo(env, outputStream); + internal const string Summary = "A loader that encapsulates a loader and a transformer chain."; - /// - /// Load the pipeline from stream. - /// - public static CompositeDataLoader LoadFrom(IHostEnvironment env, Stream stream) + internal const string LoaderSignature = "CompositeLoader"; + private static VersionInfo GetVersionInfo() { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(stream, nameof(stream)); - - env.Check(stream.CanRead && stream.CanSeek, "Need a readable and seekable stream to load"); - using (var rep = RepositoryReader.Open(stream, env)) - using (var ch = env.Start("Loading pipeline")) - { - ch.Trace("Loading data loader"); - ModelLoadContext.LoadModel, SignatureLoadModel>(env, out var loader, rep, "Reader"); - - ch.Trace("Loader transformer chain"); - ModelLoadContext.LoadModel, SignatureLoadModel>(env, out var transformerChain, rep, TransformerChain.LoaderSignature); - return new CompositeDataLoader(loader, transformerChain); - } + return new VersionInfo( + modelSignature: "CMPSTLDR", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(CompositeDataLoader<,>).Assembly.FullName); } } } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index 85c96e5476..9caf93c626 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -19,12 +19,15 @@ [assembly: LoadableClass(TextLoader.Summary, typeof(ILegacyDataLoader), typeof(TextLoader), null, typeof(SignatureLoadDataLoader), "Text Loader", TextLoader.LoaderSignature)] +[assembly: LoadableClass(TextLoader.Summary, typeof(TextLoader), null, typeof(SignatureLoadModel), + "Text Loader", TextLoader.LoaderSignature)] + namespace Microsoft.ML.Data { /// /// Loads a text file into an IDataView. Supports basic mapping from input columns to columns. /// - public sealed partial class TextLoader : IDataLoader, ICanSaveModel + public sealed partial class TextLoader : IDataLoader { /// /// Describes how an input column should be mapped to an column. @@ -1189,31 +1192,31 @@ private char NormalizeSeparator(string sep) { switch (sep) { - case "space": - case " ": - return ' '; - case "tab": - case "\t": - return '\t'; - case "comma": - case ",": - return ','; - case "colon": - case ":": - _host.CheckUserArg((_flags & OptionFlags.AllowSparse) == 0, nameof(Options.Separator), - "When the separator is colon, turn off allowSparse"); - return ':'; - case "semicolon": - case ";": - return ';'; - case "bar": - case "|": - return '|'; - default: - char ch = sep[0]; - if (sep.Length != 1 || ch < ' ' || '0' <= ch && ch <= '9' || ch == '"') - throw _host.ExceptUserArg(nameof(Options.Separator), "Illegal separator: '{0}'", sep); - return sep[0]; + case "space": + case " ": + return ' '; + case "tab": + case "\t": + return '\t'; + case "comma": + case ",": + return ','; + case "colon": + case ":": + _host.CheckUserArg((_flags & OptionFlags.AllowSparse) == 0, nameof(Options.Separator), + "When the separator is colon, turn off allowSparse"); + return ':'; + case "semicolon": + case ";": + return ';'; + case "bar": + case "|": + return '|'; + default: + char ch = sep[0]; + if (sep.Length != 1 || ch < ' ' || '0' <= ch && ch <= '9' || ch == '"') + throw _host.ExceptUserArg(nameof(Options.Separator), "Illegal separator: '{0}'", sep); + return sep[0]; } } @@ -1310,7 +1313,7 @@ private static bool TryParseSchema(IHost host, IMultiStreamSource files, error = false; options = optionsNew; - LDone: + LDone: return !error; } } @@ -1470,20 +1473,20 @@ internal static TextLoader CreateTextLoader(IHostEnvironment host, InternalDataKind dk; switch (memberInfo) { - case FieldInfo field: - if (!InternalDataKindExtensions.TryGetDataKind(field.FieldType.IsArray ? field.FieldType.GetElementType() : field.FieldType, out dk)) - throw Contracts.Except($"Field {memberInfo.Name} is of unsupported type."); + case FieldInfo field: + if (!InternalDataKindExtensions.TryGetDataKind(field.FieldType.IsArray ? field.FieldType.GetElementType() : field.FieldType, out dk)) + throw Contracts.Except($"Field {memberInfo.Name} is of unsupported type."); - break; + break; - case PropertyInfo property: - if (!InternalDataKindExtensions.TryGetDataKind(property.PropertyType.IsArray ? property.PropertyType.GetElementType() : property.PropertyType, out dk)) - throw Contracts.Except($"Property {memberInfo.Name} is of unsupported type."); - break; + case PropertyInfo property: + if (!InternalDataKindExtensions.TryGetDataKind(property.PropertyType.IsArray ? property.PropertyType.GetElementType() : property.PropertyType, out dk)) + throw Contracts.Except($"Property {memberInfo.Name} is of unsupported type."); + break; - default: - Contracts.Assert(false); - throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo"); + default: + Contracts.Assert(false); + throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo"); } column.Type = dk; diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs index 6e78598b2c..908e910eae 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs @@ -237,7 +237,7 @@ IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) /// /// Saving/loading routines for transformer chains. /// - internal static class TransformerChain + public static class TransformerChain { public const string LoaderSignature = "TransformerChain"; @@ -256,7 +256,9 @@ public static TransformerChain LoadFrom(IHostEnvironment env, Stre { try { - ModelLoadContext.LoadModel, SignatureLoadModel>(env, out var transformerChain, rep, LoaderSignature); + ModelLoadContext.LoadModelOrNull, SignatureLoadModel>(env, out var transformerChain, rep, LoaderSignature); + if (transformerChain == null) + ModelLoadContext.LoadModel, SignatureLoadModel>(env, out transformerChain, rep, $@"Model\{LoaderSignature}"); return transformerChain; } catch (FormatException ex) diff --git a/src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs b/src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs index 1e4431ca8c..0b69faabe6 100644 --- a/src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs +++ b/src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs @@ -113,7 +113,7 @@ internal override string[] GetLabelInfo(IHostEnvironment env, out DataViewType l var calibrated = predictor as IWeaklyTypedCalibratedModelParameters; while (calibrated != null) { - predictor = calibrated.WeeklyTypedSubModel; + predictor = calibrated.WeaklyTypedSubModel; calibrated = predictor as IWeaklyTypedCalibratedModelParameters; } var canGetTrainingLabelNames = predictor as ICanGetTrainingLabelNames; diff --git a/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs b/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs index dad87d6d1c..6789348229 100644 --- a/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs +++ b/src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs @@ -51,7 +51,7 @@ internal static IDataView GetSummaryAndStats(IHostEnvironment env, IPredictor pr var calibrated = predictor as IWeaklyTypedCalibratedModelParameters; while (calibrated != null) { - predictor = calibrated.WeeklyTypedSubModel; + predictor = calibrated.WeaklyTypedSubModel; calibrated = predictor as IWeaklyTypedCalibratedModelParameters; } diff --git a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs index 79b39ffa3c..28fadf1c7c 100644 --- a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs @@ -33,6 +33,18 @@ internal ModelOperationsCatalog(IHostEnvironment env) /// A writeable, seekable stream to save to. public void Save(ITransformer model, Stream stream) => model.SaveTo(_env, stream); + public void Save(IDataLoader model, Stream stream) + { + using (var rep = RepositoryWriter.CreateNew(stream)) + { + ModelSaveContext.SaveModel(rep, model, "Model"); + rep.Commit(); + } + } + + public void Save(IDataLoader loader, ITransformer model, Stream stream) => + Save(new CompositeDataLoader(loader, new TransformerChain(model)), stream); + /// /// Load the model from the stream. /// @@ -40,6 +52,15 @@ internal ModelOperationsCatalog(IHostEnvironment env) /// The loaded model. public ITransformer Load(Stream stream) => TransformerChain.LoadFrom(_env, stream); + public IDataLoader LoadAsCompositeDataLoader(Stream stream) + { + using (var rep = RepositoryReader.Open(stream)) + { + ModelLoadContext.LoadModel, SignatureLoadModel>(_env, out var model, rep, "Model"); + return model; + } + } + /// /// Load the model from a file path. /// diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index a8246c3e62..8c9ffdc08c 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -58,19 +58,19 @@ "Naive Calibration Executor", NaiveCalibrator.LoaderSignature)] -[assembly: LoadableClass(typeof(ValueMapperCalibratedModelParameters, ICalibrator>), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(ValueMapperCalibratedModelParameters), null, typeof(SignatureLoadModel), "Calibrated Predictor Executor", ValueMapperCalibratedModelParameters, ICalibrator>.LoaderSignature, "BulkCaliPredExec")] -[assembly: LoadableClass(typeof(FeatureWeightsCalibratedModelParameters, ICalibrator>), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(FeatureWeightsCalibratedModelParameters), null, typeof(SignatureLoadModel), "Feature Weights Calibrated Predictor Executor", FeatureWeightsCalibratedModelParameters, ICalibrator>.LoaderSignature)] -[assembly: LoadableClass(typeof(ParameterMixingCalibratedModelParameters, ICalibrator>), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(ParameterMixingCalibratedModelParameters), null, typeof(SignatureLoadModel), "Parameter Mixing Calibrated Predictor Executor", ParameterMixingCalibratedModelParameters, ICalibrator>.LoaderSignature)] -[assembly: LoadableClass(typeof(SchemaBindableCalibratedModelParameters, ICalibrator>), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(SchemaBindableCalibratedModelParameters), null, typeof(SignatureLoadModel), "Schema Bindable Calibrated Predictor", SchemaBindableCalibratedModelParameters, ICalibrator>.LoaderSignature)] [assembly: LoadableClass(typeof(void), typeof(Calibrate), null, typeof(SignatureEntryPointModule), "Calibrate")] @@ -147,8 +147,8 @@ internal interface ISelfCalibratingPredictor [BestFriend] internal interface IWeaklyTypedCalibratedModelParameters { - IPredictorProducing WeeklyTypedSubModel { get; } - ICalibrator WeeklyTypedCalibrator { get; } + IPredictorProducing WeaklyTypedSubModel { get; } + ICalibrator WeaklyTypedCalibrator { get; } } /// @@ -186,8 +186,8 @@ public abstract class CalibratedModelParametersBase : public TCalibrator Calibrator { get; } // Type-unsafed accessors of strongly-typed members. - IPredictorProducing IWeaklyTypedCalibratedModelParameters.WeeklyTypedSubModel => (IPredictorProducing)SubModel; - ICalibrator IWeaklyTypedCalibratedModelParameters.WeeklyTypedCalibrator => Calibrator; + IPredictorProducing IWeaklyTypedCalibratedModelParameters.WeaklyTypedSubModel => (IPredictorProducing)SubModel; + ICalibrator IWeaklyTypedCalibratedModelParameters.WeaklyTypedCalibrator => Calibrator; PredictionKind IPredictor.PredictionKind => ((IPredictorProducing)SubModel).PredictionKind; @@ -198,6 +198,7 @@ private protected CalibratedModelParametersBase(IHostEnvironment env, string nam Host = env.Register(name); Host.CheckValue(predictor, nameof(predictor)); Host.CheckValue(calibrator, nameof(calibrator)); + Host.Assert(predictor is IPredictorProducing); SubModel = predictor; Calibrator = calibrator; @@ -270,7 +271,7 @@ internal abstract class ValueMapperCalibratedModelParametersBase, IValueMapperDist, IFeatureContributionMapper, ICalculateFeatureContribution, IDistCanSavePfa, IDistCanSaveOnnx - where TSubModel : class, IPredictorProducing + where TSubModel : class where TCalibrator : class, ICalibrator { private readonly IValueMapper _mapper; @@ -380,7 +381,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string [BestFriend] internal sealed class ValueMapperCalibratedModelParameters : ValueMapperCalibratedModelParametersBase, ICanSaveModel - where TSubModel : class, IPredictorProducing + where TSubModel : class where TCalibrator : class, ICalibrator { internal ValueMapperCalibratedModelParameters(IHostEnvironment env, TSubModel predictor, TCalibrator calibrator) @@ -442,7 +443,7 @@ internal sealed class FeatureWeightsCalibratedModelParameters, IPredictorWithFeatureWeights, ICanSaveModel - where TSubModel : class, IPredictorWithFeatureWeights + where TSubModel : class where TCalibrator : class, ICalibrator { private readonly IPredictorWithFeatureWeights _featureWeights; @@ -451,7 +452,8 @@ internal FeatureWeightsCalibratedModelParameters(IHostEnvironment env, TSubModel TCalibrator calibrator) : base(env, RegistrationName, predictor, calibrator) { - _featureWeights = predictor; + Host.Assert(predictor is IPredictorWithFeatureWeights); + _featureWeights = predictor as IPredictorWithFeatureWeights; } internal const string LoaderSignature = "FeatWCaliPredExec"; @@ -506,7 +508,7 @@ internal sealed class ParameterMixingCalibratedModelParameters, IPredictorWithFeatureWeights, ICanSaveModel - where TSubModel : class, IPredictorWithFeatureWeights + where TSubModel : class where TCalibrator : class, ICalibrator { private readonly IPredictorWithFeatureWeights _featureWeights; @@ -516,7 +518,8 @@ internal ParameterMixingCalibratedModelParameters(IHostEnvironment env, TSubMode { Host.Check(predictor is IParameterMixer, "Predictor does not implement " + nameof(IParameterMixer)); Host.Check(calibrator is IParameterMixer, "Calibrator does not implement " + nameof(IParameterMixer)); - _featureWeights = predictor; + Host.Assert(predictor is IPredictorWithFeatureWeights); + _featureWeights = predictor as IPredictorWithFeatureWeights; } internal const string LoaderSignature = "PMixCaliPredExec"; @@ -538,7 +541,7 @@ private ParameterMixingCalibratedModelParameters(IHostEnvironment env, ModelLoad { Host.Check(SubModel is IParameterMixer, "Predictor does not implement " + nameof(IParameterMixer)); Host.Check(SubModel is IPredictorWithFeatureWeights, "Predictor does not implement " + nameof(IPredictorWithFeatureWeights)); - _featureWeights = SubModel; + _featureWeights = SubModel as IPredictorWithFeatureWeights; } private static ParameterMixingCalibratedModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) @@ -587,7 +590,7 @@ IParameterMixer IParameterMixer.CombineParameters(IList : CalibratedModelParametersBase, ISchemaBindableMapper, ICanSaveModel, IBindableCanSavePfa, IBindableCanSaveOnnx, IFeatureContributionMapper - where TSubModel : class, IPredictorProducing + where TSubModel : class where TCalibrator : class, ICalibrator { private sealed class Bound : ISchemaBoundRowMapper @@ -697,14 +700,14 @@ private static VersionInfo GetVersionInfo() internal SchemaBindableCalibratedModelParameters(IHostEnvironment env, TSubModel predictor, TCalibrator calibrator) : base(env, LoaderSignature, predictor, calibrator) { - _bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubModel); + _bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubModel as IPredictorProducing); _featureContribution = SubModel as IFeatureContributionMapper; } private SchemaBindableCalibratedModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, LoaderSignature, GetPredictor(env, ctx), GetCalibrator(env, ctx)) { - _bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubModel); + _bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubModel as IPredictorProducing); _featureContribution = SubModel as IFeatureContributionMapper; } diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index 38f3c3b38f..5836de85da 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -8,22 +8,22 @@ using Microsoft.ML.Data.IO; using Microsoft.ML.Runtime; -[assembly: LoadableClass(typeof(BinaryPredictionTransformer>), typeof(BinaryPredictionTransformer), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(BinaryPredictionTransformer), typeof(BinaryPredictionTransformer), null, typeof(SignatureLoadModel), "", BinaryPredictionTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(MulticlassPredictionTransformer>>), typeof(MulticlassPredictionTransformer), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(MulticlassPredictionTransformer), typeof(MulticlassPredictionTransformer), null, typeof(SignatureLoadModel), "", MulticlassPredictionTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(RegressionPredictionTransformer>), typeof(RegressionPredictionTransformer), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(RegressionPredictionTransformer), typeof(RegressionPredictionTransformer), null, typeof(SignatureLoadModel), "", RegressionPredictionTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(RankingPredictionTransformer>), typeof(RankingPredictionTransformer), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(RankingPredictionTransformer), typeof(RankingPredictionTransformer), null, typeof(SignatureLoadModel), "", RankingPredictionTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(AnomalyPredictionTransformer>), typeof(AnomalyPredictionTransformer), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(AnomalyPredictionTransformer), typeof(AnomalyPredictionTransformer), null, typeof(SignatureLoadModel), "", AnomalyPredictionTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(ClusteringPredictionTransformer>>), typeof(ClusteringPredictionTransformer), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(ClusteringPredictionTransformer), typeof(ClusteringPredictionTransformer), null, typeof(SignatureLoadModel), "", ClusteringPredictionTransformer.LoaderSignature)] namespace Microsoft.ML.Data @@ -51,8 +51,7 @@ public abstract class PredictionTransformerBase : IPredictionTransformer private protected readonly IHost Host; [BestFriend] private protected ISchemaBindableMapper BindableMapper; - [BestFriend] - private protected DataViewSchema TrainSchema; + protected DataViewSchema TrainSchema; /// /// Whether a call to should succeed, on an @@ -142,8 +141,7 @@ IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) private protected abstract void SaveModel(ModelSaveContext ctx); - [BestFriend] - private protected void SaveModelCore(ModelSaveContext ctx) + protected void SaveModelCore(ModelSaveContext ctx) { // *** Binary format *** // @@ -235,14 +233,14 @@ public sealed override DataViewSchema GetOutputSchema(DataViewSchema inputSchema return Transform(new EmptyDataView(Host, inputSchema)).Schema; } - private protected sealed override void SaveModel(ModelSaveContext ctx) + private protected override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); SaveCore(ctx); } - private protected virtual void SaveCore(ModelSaveContext ctx) + protected virtual void SaveCore(ModelSaveContext ctx) { SaveModelCore(ctx); ctx.SaveStringOrNull(FeatureColumn); @@ -268,7 +266,7 @@ public sealed class AnomalyPredictionTransformer : SingleFeaturePredicti [BestFriend] internal AnomalyPredictionTransformer(IHostEnvironment env, TModel model, DataViewSchema inputSchema, string featureColumn, float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(AnomalyPredictionTransformer)),model, inputSchema, featureColumn) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(AnomalyPredictionTransformer)), model, inputSchema, featureColumn) { Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn)); Threshold = threshold; @@ -297,7 +295,7 @@ private void SetScorer() Scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } - private protected override void SaveCore(ModelSaveContext ctx) + protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -366,7 +364,7 @@ private void SetScorer() Scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } - private protected override void SaveCore(ModelSaveContext ctx) + protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -430,7 +428,7 @@ private void SetScorer() Scorer = new MulticlassClassificationScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } - private protected override void SaveCore(ModelSaveContext ctx) + protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -475,7 +473,7 @@ internal RegressionPredictionTransformer(IHostEnvironment env, ModelLoadContext Scorer = GetGenericScorer(); } - private protected override void SaveCore(ModelSaveContext ctx) + protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -517,7 +515,7 @@ internal RankingPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx Scorer = GetGenericScorer(); } - private protected override void SaveCore(ModelSaveContext ctx) + protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -569,7 +567,7 @@ internal ClusteringPredictionTransformer(IHostEnvironment env, ModelLoadContext Scorer = new ClusteringScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } - private protected override void SaveCore(ModelSaveContext ctx) + protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -596,47 +594,47 @@ internal static class BinaryPredictionTransformer { public const string LoaderSignature = "BinaryPredXfer"; - public static BinaryPredictionTransformer> Create(IHostEnvironment env, ModelLoadContext ctx) - => new BinaryPredictionTransformer>(env, ctx); + public static BinaryPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + => new BinaryPredictionTransformer(env, ctx); } internal static class MulticlassPredictionTransformer { public const string LoaderSignature = "MulticlassPredXfer"; - public static MulticlassPredictionTransformer>> Create(IHostEnvironment env, ModelLoadContext ctx) - => new MulticlassPredictionTransformer>>(env, ctx); + public static MulticlassPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + => new MulticlassPredictionTransformer(env, ctx); } internal static class RegressionPredictionTransformer { public const string LoaderSignature = "RegressionPredXfer"; - public static RegressionPredictionTransformer> Create(IHostEnvironment env, ModelLoadContext ctx) - => new RegressionPredictionTransformer>(env, ctx); + public static RegressionPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + => new RegressionPredictionTransformer(env, ctx); } internal static class RankingPredictionTransformer { public const string LoaderSignature = "RankingPredXfer"; - public static RankingPredictionTransformer> Create(IHostEnvironment env, ModelLoadContext ctx) - => new RankingPredictionTransformer>(env, ctx); + public static RankingPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + => new RankingPredictionTransformer(env, ctx); } internal static class AnomalyPredictionTransformer { public const string LoaderSignature = "AnomalyPredXfer"; - public static AnomalyPredictionTransformer> Create(IHostEnvironment env, ModelLoadContext ctx) - => new AnomalyPredictionTransformer>(env, ctx); + public static AnomalyPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + => new AnomalyPredictionTransformer(env, ctx); } internal static class ClusteringPredictionTransformer { public const string LoaderSignature = "ClusteringPredXfer"; - public static ClusteringPredictionTransformer>> Create(IHostEnvironment env, ModelLoadContext ctx) - => new ClusteringPredictionTransformer>>(env, ctx); + public static ClusteringPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) + => new ClusteringPredictionTransformer(env, ctx); } } diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs index b086eea738..553c5cf6a1 100644 --- a/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs @@ -38,7 +38,7 @@ public static SummaryOutput Summarize(IHostEnvironment env, SummarizePredictor.I var calibrated = predictor as IWeaklyTypedCalibratedModelParameters; while (calibrated != null) { - predictor = calibrated.WeeklyTypedSubModel; + predictor = calibrated.WeaklyTypedSubModel; calibrated = predictor as IWeaklyTypedCalibratedModelParameters; } var ensemble = predictor as SchemaBindablePipelineEnsembleBase; diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs index 6d847c4edc..b938ee3f5b 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs @@ -23,14 +23,14 @@ public TreeEnsembleCombiner(IHostEnvironment env, PredictionKind kind) _host = env.Register("TreeEnsembleCombiner"); switch (kind) { - case PredictionKind.BinaryClassification: - case PredictionKind.Regression: - case PredictionKind.Ranking: - _kind = kind; - break; - default: - throw _host.ExceptUserArg(nameof(kind), $"Tree ensembles can be either of type {nameof(PredictionKind.BinaryClassification)}, " + - $"{nameof(PredictionKind.Regression)} or {nameof(PredictionKind.Ranking)}"); + case PredictionKind.BinaryClassification: + case PredictionKind.Regression: + case PredictionKind.Ranking: + _kind = kind; + break; + default: + throw _host.ExceptUserArg(nameof(kind), $"Tree ensembles can be either of type {nameof(PredictionKind.BinaryClassification)}, " + + $"{nameof(PredictionKind.Regression)} or {nameof(PredictionKind.Ranking)}"); } } @@ -52,11 +52,11 @@ IPredictor IModelCombiner.CombineModels(IEnumerable models) var calibrated = predictor as IWeaklyTypedCalibratedModelParameters; double paramA = 1; if (calibrated != null) - _host.Check(calibrated.WeeklyTypedCalibrator is PlattCalibrator, + _host.Check(calibrated.WeaklyTypedCalibrator is PlattCalibrator, "Combining FastTree models can only be done when the models are calibrated with Platt calibrator"); - predictor = calibrated.WeeklyTypedSubModel; - paramA = -((PlattCalibrator)calibrated.WeeklyTypedCalibrator).Slope; + predictor = calibrated.WeaklyTypedSubModel; + paramA = -((PlattCalibrator)calibrated.WeaklyTypedCalibrator).Slope; var tree = predictor as TreeEnsembleModelParameters; @@ -99,20 +99,20 @@ IPredictor IModelCombiner.CombineModels(IEnumerable models) switch (_kind) { - case PredictionKind.BinaryClassification: - if (!binaryClassifier) - return new FastTreeBinaryModelParameters(_host, ensemble, featureCount, null); - - var cali = new PlattCalibrator(_host, -1, 0); - var fastTreeModel = new FastTreeBinaryModelParameters(_host, ensemble, featureCount, null); - return new FeatureWeightsCalibratedModelParameters(_host, fastTreeModel, cali); - case PredictionKind.Regression: - return new FastTreeRegressionModelParameters(_host, ensemble, featureCount, null); - case PredictionKind.Ranking: - return new FastTreeRankingModelParameters(_host, ensemble, featureCount, null); - default: - _host.Assert(false); - throw _host.ExceptNotSupp(); + case PredictionKind.BinaryClassification: + if (!binaryClassifier) + return new FastTreeBinaryModelParameters(_host, ensemble, featureCount, null); + + var cali = new PlattCalibrator(_host, -1, 0); + var fastTreeModel = new FastTreeBinaryModelParameters(_host, ensemble, featureCount, null); + return new FeatureWeightsCalibratedModelParameters(_host, fastTreeModel, cali); + case PredictionKind.Regression: + return new FastTreeRegressionModelParameters(_host, ensemble, featureCount, null); + case PredictionKind.Ranking: + return new FastTreeRankingModelParameters(_host, ensemble, featureCount, null); + default: + _host.Assert(false); + throw _host.ExceptNotSupp(); } } } diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index 92c52d3eb7..c4d2b191f4 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -590,7 +590,7 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData // Make sure that the given predictor has the correct number of input features. if (predictor is IWeaklyTypedCalibratedModelParameters calibrated) - predictor = calibrated.WeeklyTypedSubModel; + predictor = calibrated.WeaklyTypedSubModel; // Predictor should be a TreeEnsembleModelParameters, which implements IValueMapper, so this should // be non-null. var vm = predictor as IValueMapper; diff --git a/src/Microsoft.ML.StandardTrainers/Standard/SdcaBinary.cs b/src/Microsoft.ML.StandardTrainers/Standard/SdcaBinary.cs index 890d9acb38..dc4653365a 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/SdcaBinary.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/SdcaBinary.cs @@ -72,7 +72,7 @@ private protected override TModel TrainModelCore(TrainContext context) { var preparedData = PrepareDataFromTrainingExamples(ch, context.TrainingSet, out int weightSetCount); var initPred = context.InitialPredictor; - var linInitPred = (initPred as IWeaklyTypedCalibratedModelParameters)?.WeeklyTypedSubModel as LinearModelParameters; + var linInitPred = (initPred as IWeaklyTypedCalibratedModelParameters)?.WeaklyTypedSubModel as LinearModelParameters; linInitPred = linInitPred ?? initPred as LinearModelParameters; Host.CheckParam(context.InitialPredictor == null || linInitPred != null, nameof(context), "Initial predictor was not a linear predictor."); diff --git a/src/Microsoft.ML.StandardTrainers/Standard/StochasticTrainerBase.cs b/src/Microsoft.ML.StandardTrainers/Standard/StochasticTrainerBase.cs index 154285d926..28583b0ab7 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/StochasticTrainerBase.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/StochasticTrainerBase.cs @@ -35,7 +35,7 @@ private protected override TModel TrainModelCore(TrainContext context) var preparedData = PrepareDataFromTrainingExamples(ch, context.TrainingSet, out int weightSetCount); var initPred = context.InitialPredictor; // Try extract linear model from calibrated predictor. - var linInitPred = (initPred as IWeaklyTypedCalibratedModelParameters)?.WeeklyTypedSubModel as LinearModelParameters; + var linInitPred = (initPred as IWeaklyTypedCalibratedModelParameters)?.WeaklyTypedSubModel as LinearModelParameters; // If the initial predictor is not calibrated, it should be a linear model. linInitPred = linInitPred ?? initPred as LinearModelParameters; Host.CheckParam(context.InitialPredictor == null || linInitPred != null, nameof(context), diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs new file mode 100644 index 0000000000..07718e8e01 --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs @@ -0,0 +1,114 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.IO; +using System.Linq; +using Microsoft.ML.Calibrators; +using Microsoft.ML.Data; +using Microsoft.ML.RunTests; +using Microsoft.ML.Trainers.FastTree; +using Xunit; + +namespace Microsoft.ML.Tests.Scenarios.Api +{ + public partial class ApiScenariosTests + { + private class InputData + { + [LoadColumn(0)] + public float Label { get; set; } + [LoadColumn(9, 14)] + [VectorType(6)] + public float[] Features { get; set; } + } + + [Fact] + public void LoadModelAndExtractPredictor() + { + var ml = new MLContext(seed: 1, conc: 1); + var file = new MultiFileSource(GetDataPath(TestDatasets.adult.trainFilename)); + var loader = ml.Data.CreateTextLoader(hasHeader: true, dataSample: file); + var data = loader.Load(file); + + // Pipeline. + var pipeline = ml.BinaryClassification.Trainers.GeneralizedAdditiveModels(); + + // Train. + var model = pipeline.Fit(data); + + // Save and reload. + string modelPath = GetOutputPath(FullTestName + "-model.zip"); + using (var fs = File.Create(modelPath)) + ml.Model.Save(model, fs); + + ITransformer loadedModel; + using (var fs = File.OpenRead(modelPath)) + loadedModel = ml.Model.Load(fs); + + var gam = (((loadedModel as TransformerChain).LastTransformer + as BinaryPredictionTransformer).Model + as CalibratedModelParametersBase).SubModel + as BinaryClassificationGamModelParameters; + Assert.NotNull(gam); + } + + [Fact] + public void SaveAndLoadModelWithLoader() + { + var ml = new MLContext(seed: 1, conc: 1); + var file = new MultiFileSource(GetDataPath(TestDatasets.adult.trainFilename)); + var loader = ml.Data.CreateTextLoader(hasHeader: true, dataSample: file); + var data = loader.Load(file); + + // Pipeline. + var pipeline = ml.BinaryClassification.Trainers.GeneralizedAdditiveModels(); + + // Train. + var model = pipeline.Fit(data); + + // Save and reload. + string modelPath = GetOutputPath(FullTestName + "-model.zip"); + using (var fs = File.Create(modelPath)) + ml.Model.Save(loader, model, fs); + + IDataLoader loadedModel; + ITransformer loadedModelWithoutLoader; + using (var fs = File.OpenRead(modelPath)) + { + loadedModel = ml.Model.LoadAsCompositeDataLoader(fs); + loadedModelWithoutLoader = ml.Model.Load(fs); + } + + // Without deserializing the loader from the model we lose the slot names. + data = ml.Data.LoadFromEnumerable(new[] { new InputData() }); + data = loadedModelWithoutLoader.Transform(data); + Assert.Null(data.Schema["Features"].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)); + + data = loadedModel.Load(file); + Assert.True(data.Schema["Features"].HasSlotNames(data.Schema["Features"].Type.GetValueCount())); + VBuffer> slotNames = default; + data.Schema["Features"].GetSlotNames(ref slotNames); + var ageIndex = FindIndex(slotNames.GetValues(), "age"); + var transformer = (loadedModel as CompositeDataLoader).Transformer.LastTransformer; + var gamModel = ((transformer as BinaryPredictionTransformer).Model + as CalibratedModelParametersBase).SubModel + as BinaryClassificationGamModelParameters; + var ageBinUpperBounds = gamModel.GetBinUpperBounds(ageIndex); + var ageBinEffects = gamModel.GetBinEffects(ageIndex); + } + + private int FindIndex(ReadOnlySpan> values, string slotName) + { + int index = 0; + foreach (var value in values) + { + if (value.Span.SequenceEqual(slotName.AsSpan())) + return index; + index++; + } + return -1; + } + } +} From 49972021dac769b922fe552088bbb43e8bfc30cd Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Tue, 5 Mar 2019 15:10:38 -0800 Subject: [PATCH 02/18] Address some code review comments, add a non-generic base class for calibrated predictor --- .../DataLoadSave/CompositeDataLoader.cs | 23 ------------------ .../Model/ModelOperationsCatalog.cs | 4 ++-- .../Prediction/Calibrator.cs | 24 +++++++++++++++---- .../Api/Estimators/DeserializationTests.cs | 4 ++-- 4 files changed, 23 insertions(+), 32 deletions(-) diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs index 98fb04f309..0b47bd4ff4 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs @@ -98,29 +98,6 @@ void ICanSaveModel.Save(ModelSaveContext ctx) ctx.SaveModel(Transformer, TransformerDirectory); } - /// - /// Save the contents to a stream, as a "model file". - /// - public void SaveTo(IHostEnvironment env, Stream outputStream) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(outputStream, nameof(outputStream)); - - env.Check(outputStream.CanWrite && outputStream.CanSeek, "Need a writable and seekable stream to save"); - using (var ch = env.Start("Saving pipeline")) - { - using (var rep = RepositoryWriter.CreateNew(outputStream, ch)) - { - ch.Trace("Saving data loader"); - ModelSaveContext.SaveModel(rep, Loader, LoaderDirectory); - - ch.Trace("Saving transformer chain"); - ModelSaveContext.SaveModel(rep, Transformer, TransformerDirectory); - rep.Commit(); - } - } - } - internal const string Summary = "A loader that encapsulates a loader and a transformer chain."; internal const string LoaderSignature = "CompositeLoader"; diff --git a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs index 28fadf1c7c..3e80f37cc2 100644 --- a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs @@ -52,11 +52,11 @@ public void Save(IDataLoader loader, ITransformer model, Strea /// The loaded model. public ITransformer Load(Stream stream) => TransformerChain.LoadFrom(_env, stream); - public IDataLoader LoadAsCompositeDataLoader(Stream stream) + public CompositeDataLoader LoadAsCompositeDataLoader(Stream stream) { using (var rep = RepositoryReader.Open(stream)) { - ModelLoadContext.LoadModel, SignatureLoadModel>(_env, out var model, rep, "Model"); + ModelLoadContext.LoadModel, SignatureLoadModel>(_env, out var model, rep, "Model"); return model; } } diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 8c9ffdc08c..c6bedc6cec 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -58,7 +58,7 @@ "Naive Calibration Executor", NaiveCalibrator.LoaderSignature)] -[assembly: LoadableClass(typeof(ValueMapperCalibratedModelParameters), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(CalibratedModelParametersBase), typeof(ValueMapperCalibratedModelParameters, ICalibrator>), null, typeof(SignatureLoadModel), "Calibrated Predictor Executor", ValueMapperCalibratedModelParameters, ICalibrator>.LoaderSignature, "BulkCaliPredExec")] @@ -151,6 +151,18 @@ internal interface IWeaklyTypedCalibratedModelParameters ICalibrator WeaklyTypedCalibrator { get; } } + public abstract class CalibratedModelParametersBase + { + public object SubModel { get; } + public ICalibrator Calibrator { get; } + + private protected CalibratedModelParametersBase(object subModel, ICalibrator calibrator) + { + SubModel = subModel; + Calibrator = calibrator; + } + } + /// /// Class for allowing a post-processing step, defined by , to 's /// output. @@ -162,7 +174,7 @@ internal interface IWeaklyTypedCalibratedModelParameters /// output value to the probability of belonging to the positive (or negative) class. Detailed math materials /// can be found at this paper. /// - public abstract class CalibratedModelParametersBase : + public abstract class CalibratedModelParametersBase : CalibratedModelParametersBase, IDistPredictorProducing, ICanSaveInIniFormat, ICanSaveInTextFormat, @@ -179,11 +191,12 @@ public abstract class CalibratedModelParametersBase : /// /// 's output would calibrated by . /// - public TSubModel SubModel { get; } + public new TSubModel SubModel { get; } + /// /// is used to post-process score produced by . /// - public TCalibrator Calibrator { get; } + public new TCalibrator Calibrator { get; } // Type-unsafed accessors of strongly-typed members. IPredictorProducing IWeaklyTypedCalibratedModelParameters.WeaklyTypedSubModel => (IPredictorProducing)SubModel; @@ -192,6 +205,7 @@ public abstract class CalibratedModelParametersBase : PredictionKind IPredictor.PredictionKind => ((IPredictorProducing)SubModel).PredictionKind; private protected CalibratedModelParametersBase(IHostEnvironment env, string name, TSubModel predictor, TCalibrator calibrator) + : base(predictor, calibrator) { Contracts.CheckValue(env, nameof(env)); env.CheckNonWhiteSpace(name, nameof(name)); @@ -418,7 +432,7 @@ private ValueMapperCalibratedModelParameters(IHostEnvironment env, ModelLoadCont { } - private static ValueMapperCalibratedModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) + private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); // Can load either the old "bulk" model or standard "cali". The two formats are identical. diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs index 07718e8e01..2ee0da42bf 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs @@ -49,7 +49,7 @@ public void LoadModelAndExtractPredictor() var gam = (((loadedModel as TransformerChain).LastTransformer as BinaryPredictionTransformer).Model - as CalibratedModelParametersBase).SubModel + as CalibratedModelParametersBase).SubModel as BinaryClassificationGamModelParameters; Assert.NotNull(gam); } @@ -93,7 +93,7 @@ public void SaveAndLoadModelWithLoader() var ageIndex = FindIndex(slotNames.GetValues(), "age"); var transformer = (loadedModel as CompositeDataLoader).Transformer.LastTransformer; var gamModel = ((transformer as BinaryPredictionTransformer).Model - as CalibratedModelParametersBase).SubModel + as CalibratedModelParametersBase).SubModel as BinaryClassificationGamModelParameters; var ageBinUpperBounds = gamModel.GetBinUpperBounds(ageIndex); var ageBinEffects = gamModel.GetBinEffects(ageIndex); From f4f322693dc99a4b5d98af6b45264f7619f2edb8 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Thu, 7 Mar 2019 15:08:00 -0800 Subject: [PATCH 03/18] use the contravariance of ISingleFeaturePredictionTransformer instead of loading PredictionTransformer from file --- .../Prediction/Calibrator.cs | 12 ++-- .../Prediction/IPredictionTransformer.cs | 2 + .../Scorers/PredictionTransformer.cs | 58 ++++++++++--------- .../PermutationFeatureImportance.cs | 1 + .../PermutationFeatureImportanceExtensions.cs | 8 +-- .../Api/Estimators/DeserializationTests.cs | 11 ++-- 6 files changed, 50 insertions(+), 42 deletions(-) diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index c6bedc6cec..2cf25b5ebb 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -62,15 +62,15 @@ "Calibrated Predictor Executor", ValueMapperCalibratedModelParameters, ICalibrator>.LoaderSignature, "BulkCaliPredExec")] -[assembly: LoadableClass(typeof(FeatureWeightsCalibratedModelParameters), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(CalibratedModelParametersBase), typeof(FeatureWeightsCalibratedModelParameters, ICalibrator>), null, typeof(SignatureLoadModel), "Feature Weights Calibrated Predictor Executor", FeatureWeightsCalibratedModelParameters, ICalibrator>.LoaderSignature)] -[assembly: LoadableClass(typeof(ParameterMixingCalibratedModelParameters), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(CalibratedModelParametersBase), typeof(ParameterMixingCalibratedModelParameters, ICalibrator>), null, typeof(SignatureLoadModel), "Parameter Mixing Calibrated Predictor Executor", ParameterMixingCalibratedModelParameters, ICalibrator>.LoaderSignature)] -[assembly: LoadableClass(typeof(SchemaBindableCalibratedModelParameters), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(CalibratedModelParametersBase), typeof(SchemaBindableCalibratedModelParameters, ICalibrator>), null, typeof(SignatureLoadModel), "Schema Bindable Calibrated Predictor", SchemaBindableCalibratedModelParameters, ICalibrator>.LoaderSignature)] [assembly: LoadableClass(typeof(void), typeof(Calibrate), null, typeof(SignatureEntryPointModule), "Calibrate")] @@ -491,7 +491,7 @@ private FeatureWeightsCalibratedModelParameters(IHostEnvironment env, ModelLoadC _featureWeights = (IPredictorWithFeatureWeights)SubModel; } - private static FeatureWeightsCalibratedModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) + private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); @@ -558,7 +558,7 @@ private ParameterMixingCalibratedModelParameters(IHostEnvironment env, ModelLoad _featureWeights = SubModel as IPredictorWithFeatureWeights; } - private static ParameterMixingCalibratedModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) + private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); @@ -725,7 +725,7 @@ private SchemaBindableCalibratedModelParameters(IHostEnvironment env, ModelLoadC _featureContribution = SubModel as IFeatureContributionMapper; } - private static SchemaBindableCalibratedModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) + private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); diff --git a/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs b/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs index b3a342de52..fb2251fce8 100644 --- a/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs @@ -14,6 +14,7 @@ namespace Microsoft.ML /// /// The or used for the data transformation. public interface IPredictionTransformer : ITransformer + where TModel : class { TModel Model { get; } } @@ -25,6 +26,7 @@ public interface IPredictionTransformer : ITransformer /// /// The or used for the data transformation. public interface ISingleFeaturePredictionTransformer : IPredictionTransformer + where TModel : class { /// The name of the feature column. string FeatureColumn { get; } diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index 5836de85da..399eed9107 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -8,22 +8,22 @@ using Microsoft.ML.Data.IO; using Microsoft.ML.Runtime; -[assembly: LoadableClass(typeof(BinaryPredictionTransformer), typeof(BinaryPredictionTransformer), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(BinaryPredictionTransformer>), typeof(BinaryPredictionTransformer), null, typeof(SignatureLoadModel), "", BinaryPredictionTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(MulticlassPredictionTransformer), typeof(MulticlassPredictionTransformer), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(MulticlassPredictionTransformer>>), typeof(MulticlassPredictionTransformer), null, typeof(SignatureLoadModel), "", MulticlassPredictionTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(RegressionPredictionTransformer), typeof(RegressionPredictionTransformer), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(RegressionPredictionTransformer>), typeof(RegressionPredictionTransformer), null, typeof(SignatureLoadModel), "", RegressionPredictionTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(RankingPredictionTransformer), typeof(RankingPredictionTransformer), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(RankingPredictionTransformer>), typeof(RankingPredictionTransformer), null, typeof(SignatureLoadModel), "", RankingPredictionTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(AnomalyPredictionTransformer), typeof(AnomalyPredictionTransformer), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(AnomalyPredictionTransformer>), typeof(AnomalyPredictionTransformer), null, typeof(SignatureLoadModel), "", AnomalyPredictionTransformer.LoaderSignature)] -[assembly: LoadableClass(typeof(ClusteringPredictionTransformer), typeof(ClusteringPredictionTransformer), null, typeof(SignatureLoadModel), +[assembly: LoadableClass(typeof(ClusteringPredictionTransformer>>), typeof(ClusteringPredictionTransformer), null, typeof(SignatureLoadModel), "", ClusteringPredictionTransformer.LoaderSignature)] namespace Microsoft.ML.Data @@ -51,7 +51,8 @@ public abstract class PredictionTransformerBase : IPredictionTransformer private protected readonly IHost Host; [BestFriend] private protected ISchemaBindableMapper BindableMapper; - protected DataViewSchema TrainSchema; + [BestFriend] + private protected DataViewSchema TrainSchema; /// /// Whether a call to should succeed, on an @@ -141,7 +142,8 @@ IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) private protected abstract void SaveModel(ModelSaveContext ctx); - protected void SaveModelCore(ModelSaveContext ctx) + [BestFriend] + private protected void SaveModelCore(ModelSaveContext ctx) { // *** Binary format *** // @@ -233,14 +235,14 @@ public sealed override DataViewSchema GetOutputSchema(DataViewSchema inputSchema return Transform(new EmptyDataView(Host, inputSchema)).Schema; } - private protected override void SaveModel(ModelSaveContext ctx) + private protected sealed override void SaveModel(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(); SaveCore(ctx); } - protected virtual void SaveCore(ModelSaveContext ctx) + private protected virtual void SaveCore(ModelSaveContext ctx) { SaveModelCore(ctx); ctx.SaveStringOrNull(FeatureColumn); @@ -295,7 +297,7 @@ private void SetScorer() Scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -364,7 +366,7 @@ private void SetScorer() Scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -428,7 +430,7 @@ private void SetScorer() Scorer = new MulticlassClassificationScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -473,7 +475,7 @@ internal RegressionPredictionTransformer(IHostEnvironment env, ModelLoadContext Scorer = GetGenericScorer(); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -515,7 +517,7 @@ internal RankingPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx Scorer = GetGenericScorer(); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -567,7 +569,7 @@ internal ClusteringPredictionTransformer(IHostEnvironment env, ModelLoadContext Scorer = new ClusteringScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); } - protected override void SaveCore(ModelSaveContext ctx) + private protected override void SaveCore(ModelSaveContext ctx) { Contracts.AssertValue(ctx); ctx.SetVersionInfo(GetVersionInfo()); @@ -594,47 +596,47 @@ internal static class BinaryPredictionTransformer { public const string LoaderSignature = "BinaryPredXfer"; - public static BinaryPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) - => new BinaryPredictionTransformer(env, ctx); + public static BinaryPredictionTransformer> Create(IHostEnvironment env, ModelLoadContext ctx) + => new BinaryPredictionTransformer>(env, ctx); } internal static class MulticlassPredictionTransformer { public const string LoaderSignature = "MulticlassPredXfer"; - public static MulticlassPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) - => new MulticlassPredictionTransformer(env, ctx); + public static MulticlassPredictionTransformer>> Create(IHostEnvironment env, ModelLoadContext ctx) + => new MulticlassPredictionTransformer>>(env, ctx); } internal static class RegressionPredictionTransformer { public const string LoaderSignature = "RegressionPredXfer"; - public static RegressionPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) - => new RegressionPredictionTransformer(env, ctx); + public static RegressionPredictionTransformer> Create(IHostEnvironment env, ModelLoadContext ctx) + => new RegressionPredictionTransformer>(env, ctx); } internal static class RankingPredictionTransformer { public const string LoaderSignature = "RankingPredXfer"; - public static RankingPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) - => new RankingPredictionTransformer(env, ctx); + public static RankingPredictionTransformer> Create(IHostEnvironment env, ModelLoadContext ctx) + => new RankingPredictionTransformer>(env, ctx); } internal static class AnomalyPredictionTransformer { public const string LoaderSignature = "AnomalyPredXfer"; - public static AnomalyPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) - => new AnomalyPredictionTransformer(env, ctx); + public static AnomalyPredictionTransformer> Create(IHostEnvironment env, ModelLoadContext ctx) + => new AnomalyPredictionTransformer>(env, ctx); } internal static class ClusteringPredictionTransformer { public const string LoaderSignature = "ClusteringPredXfer"; - public static ClusteringPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx) - => new ClusteringPredictionTransformer(env, ctx); + public static ClusteringPredictionTransformer>> Create(IHostEnvironment env, ModelLoadContext ctx) + => new ClusteringPredictionTransformer>>(env, ctx); } } diff --git a/src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs b/src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs index 16b0925174..1bd034ff74 100644 --- a/src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs +++ b/src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs @@ -15,6 +15,7 @@ namespace Microsoft.ML.Transforms { internal static class PermutationFeatureImportance where TResult : IMetricsStatistics + where TModel : class { public static ImmutableArray GetImportanceMetricsMatrix( diff --git a/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs b/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs index afae7e7f1b..c66a7ed0de 100644 --- a/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs +++ b/src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs @@ -63,7 +63,7 @@ public static ImmutableArray string features = DefaultColumnNames.Features, bool useFeatureWeightFilter = false, int? topExamples = null, - int permutationCount = 1) + int permutationCount = 1) where TModel : class { return PermutationFeatureImportance.GetImportanceMetricsMatrix( catalog.GetEnvironment(), @@ -141,7 +141,7 @@ public static ImmutableArray string features = DefaultColumnNames.Features, bool useFeatureWeightFilter = false, int? topExamples = null, - int permutationCount = 1) + int permutationCount = 1) where TModel : class { return PermutationFeatureImportance.GetImportanceMetricsMatrix( catalog.GetEnvironment(), @@ -216,7 +216,7 @@ public static ImmutableArray string features = DefaultColumnNames.Features, bool useFeatureWeightFilter = false, int? topExamples = null, - int permutationCount = 1) + int permutationCount = 1) where TModel : class { return PermutationFeatureImportance.GetImportanceMetricsMatrix( catalog.GetEnvironment(), @@ -298,7 +298,7 @@ public static ImmutableArray string features = DefaultColumnNames.Features, bool useFeatureWeightFilter = false, int? topExamples = null, - int permutationCount = 1) + int permutationCount = 1) where TModel : class { return PermutationFeatureImportance.GetImportanceMetricsMatrix( catalog.GetEnvironment(), diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs index 2ee0da42bf..ef901b5d3d 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs @@ -48,7 +48,7 @@ public void LoadModelAndExtractPredictor() loadedModel = ml.Model.Load(fs); var gam = (((loadedModel as TransformerChain).LastTransformer - as BinaryPredictionTransformer).Model + as ISingleFeaturePredictionTransformer).Model as CalibratedModelParametersBase).SubModel as BinaryClassificationGamModelParameters; Assert.NotNull(gam); @@ -92,9 +92,12 @@ public void SaveAndLoadModelWithLoader() data.Schema["Features"].GetSlotNames(ref slotNames); var ageIndex = FindIndex(slotNames.GetValues(), "age"); var transformer = (loadedModel as CompositeDataLoader).Transformer.LastTransformer; - var gamModel = ((transformer as BinaryPredictionTransformer).Model - as CalibratedModelParametersBase).SubModel - as BinaryClassificationGamModelParameters; + var singleFeaturePredictionTransformer = transformer as ISingleFeaturePredictionTransformer; + Assert.NotNull(singleFeaturePredictionTransformer); + var calibratedModelParameters = singleFeaturePredictionTransformer.Model as CalibratedModelParametersBase; + Assert.NotNull(calibratedModelParameters); + var gamModel = calibratedModelParameters.SubModel as BinaryClassificationGamModelParameters; + Assert.NotNull(gamModel); var ageBinUpperBounds = gamModel.GetBinUpperBounds(ageIndex); var ageBinEffects = gamModel.GetBinEffects(ageIndex); } From b1de318c40cc8fb7595af9750b2ab67e1cf5e1ee Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Mon, 11 Mar 2019 11:05:42 -0700 Subject: [PATCH 04/18] Add API for saving/loading input schema --- .../DataLoadSave/CompositeDataLoader.cs | 14 ++-- .../DataLoadSave/DataOperationsCatalog.cs | 8 ++ .../DataLoadSave/TransformerChain.cs | 6 +- .../DataView/DataViewConstructionUtils.cs | 48 ++++++++++- .../Model/ModelOperationsCatalog.cs | 84 +++++++++++++++++-- .../UnitTests/TestEntryPoints.cs | 2 +- .../DataPipe/TestDataPipeBase.cs | 8 +- test/Microsoft.ML.Tests/ImagesTests.cs | 5 +- .../Api/CookbookSamples/CookbookSamples.cs | 4 +- .../CookbookSamplesDynamicApi.cs | 8 +- .../Api/Estimators/DeserializationTests.cs | 13 +-- .../MatrixFactorizationTests.cs | 6 +- .../Transformers/ConvertTests.cs | 2 +- 13 files changed, 171 insertions(+), 37 deletions(-) diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs index 0b47bd4ff4..f7869de3d6 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs @@ -19,10 +19,6 @@ namespace Microsoft.ML.Data public sealed class CompositeDataLoader : IDataLoader where TLastTransformer : class, ITransformer { - private const string LoaderDirectory = "Loader"; - private const string LegacyLoaderDirectory = "Reader"; - private const string TransformerDirectory = TransformerChain.LoaderSignature; - /// /// The underlying data loader. /// @@ -43,9 +39,9 @@ public CompositeDataLoader(IDataLoader loader, TransformerChain, SignatureLoadModel>(host, out Loader, LegacyLoaderDirectory)) - ctx.LoadModel, SignatureLoadModel>(host, out Loader, LoaderDirectory); - ctx.LoadModel, SignatureLoadModel>(host, out Transformer, TransformerDirectory); + if (!ctx.LoadModelOrNull, SignatureLoadModel>(host, out Loader, ModelOperationsCatalog.LegacyLoaderDirectory)) + ctx.LoadModel, SignatureLoadModel>(host, out Loader, ModelOperationsCatalog.LoaderDirectory); + ctx.LoadModel, SignatureLoadModel>(host, out Transformer, ModelOperationsCatalog.TransformerDirectory); } private static CompositeDataLoader Create(IHostEnvironment env, ModelLoadContext ctx) @@ -94,8 +90,8 @@ void ICanSaveModel.Save(ModelSaveContext ctx) ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); - ctx.SaveModel(Loader, LoaderDirectory); - ctx.SaveModel(Transformer, TransformerDirectory); + ctx.SaveModel(Loader, ModelOperationsCatalog.LoaderDirectory); + ctx.SaveModel(Transformer, ModelOperationsCatalog.TransformerDirectory); } internal const string Summary = "A loader that encapsulates a loader and a transformer chain."; diff --git a/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs b/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs index 2da1fb2e98..098ee8c59a 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs @@ -82,6 +82,14 @@ public IDataView LoadFromEnumerable(IEnumerable data, SchemaDefiniti return DataViewConstructionUtils.CreateFromEnumerable(_env, data, schemaDefinition); } + public IDataView LoadFromEnumerable(IEnumerable data, DataViewSchema schema) + where TRow : class + { + _env.CheckValue(data, nameof(data)); + _env.CheckValue(schema, nameof(schema)); + return DataViewConstructionUtils.CreateFromEnumerable(_env, data, schema); + } + /// /// Convert an into a strongly-typed . /// diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs index 908e910eae..a5f7bc621e 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs @@ -250,15 +250,15 @@ private static TransformerChain Create(IHostEnvironment env, Model public static void SaveTo(this ITransformer transformer, IHostEnvironment env, Stream outputStream) => new TransformerChain(transformer).SaveTo(env, outputStream); - public static TransformerChain LoadFrom(IHostEnvironment env, Stream stream) + public static ITransformer LoadFrom(IHostEnvironment env, Stream stream) { using (var rep = RepositoryReader.Open(stream, env)) { try { - ModelLoadContext.LoadModelOrNull, SignatureLoadModel>(env, out var transformerChain, rep, LoaderSignature); + ModelLoadContext.LoadModelOrNull(env, out var transformerChain, rep, LoaderSignature); if (transformerChain == null) - ModelLoadContext.LoadModel, SignatureLoadModel>(env, out transformerChain, rep, $@"Model\{LoaderSignature}"); + ModelLoadContext.LoadModel(env, out transformerChain, rep, $@"Model\{LoaderSignature}"); return transformerChain; } catch (FormatException ex) diff --git a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs index bc8d234bc2..8ed0dbfa61 100644 --- a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs @@ -45,6 +45,50 @@ public static StreamingDataView CreateFromEnumerable(IHostEnvironmen return new StreamingDataView(env, data, internalSchemaDefn); } + public static StreamingDataView CreateFromEnumerable(IHostEnvironment env, IEnumerable data, + DataViewSchema schema) + where TRow : class + { + Contracts.AssertValue(env); + env.AssertValue(data); + env.AssertValueOrNull(schema); + schema = schema ?? new DataViewSchema.Builder().ToSchema(); + return new StreamingDataView(env, data, GetInternalSchemaDefinition(env, schema)); + } + + private static InternalSchemaDefinition GetInternalSchemaDefinition(IHostEnvironment env, DataViewSchema schema) + where TRow : class + { + Contracts.AssertValue(env); + env.AssertValue(schema); + + var isd = InternalSchemaDefinition.Create(typeof(TRow), SchemaDefinition.Direction.Read); + foreach (var col in schema) + { + var name = col.Name; + var isdCol = isd.Columns.FirstOrDefault(c => c.ColumnName == name); + if (isdCol == null) + throw env.Except($"Type should contain a member named {isdCol.ColumnName}"); + var annotations = col.Annotations; + if (annotations != null) + { + foreach (var annotation in annotations.Schema) + { + var info = Utils.MarshalInvoke(GetAnnotationInfo, annotation.Type.RawType, annotation.Name, annotations); + isdCol.Annotations.Add(annotation.Name, info); + } + } + } + return isd; + } + + private static AnnotationInfo GetAnnotationInfo(string kind, DataViewSchema.Annotations annotations) + { + T value = default; + annotations.GetValue(kind, ref value); + return new AnnotationInfo(kind, value); + } + public static InputRow CreateInputRow(IHostEnvironment env, SchemaDefinition schemaDefinition = null) where TRow : class { @@ -626,7 +670,7 @@ public StreamingDataView(IHostEnvironment env, IEnumerable data, InternalS public override DataViewRowCursor GetRowCursor(IEnumerable columnsNeeded, Random rand = null) { var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, Schema); - return new WrappedCursor (new Cursor(Host, this, predicate)); + return new WrappedCursor(new Cursor(Host, this, predicate)); } private sealed class Cursor : DataViewCursorBase @@ -696,7 +740,7 @@ public override DataViewRowCursor GetRowCursor(IEnumerable public sealed class ModelOperationsCatalog : IInternalCatalog { + internal const string LoaderDirectory = "Loader"; + internal const string LegacyLoaderDirectory = "Reader"; + internal const string TransformerDirectory = TransformerChain.LoaderSignature; + internal const string SchemaEntryName = "Schema"; + IHostEnvironment IInternalCatalog.Environment => _env; private readonly IHostEnvironment _env; @@ -31,28 +40,93 @@ internal ModelOperationsCatalog(IHostEnvironment env) /// /// The trained model to be saved. /// A writeable, seekable stream to save to. - public void Save(ITransformer model, Stream stream) => model.SaveTo(_env, stream); - public void Save(IDataLoader model, Stream stream) { using (var rep = RepositoryWriter.CreateNew(stream)) { ModelSaveContext.SaveModel(rep, model, "Model"); + SaveInputSchema(model.GetOutputSchema(), rep); rep.Commit(); } } + /// + /// Save a transformer model and the loader used to create its input data to the stream. + /// + /// The loader that was used to create data to train the model + /// The trained model to be saved + /// A writeable, seekable stream to save to. public void Save(IDataLoader loader, ITransformer model, Stream stream) => Save(new CompositeDataLoader(loader, new TransformerChain(model)), stream); /// - /// Load the model from the stream. + /// Save a transformer model and the schema of the data that was used to train it to the stream. + /// + /// The schema of the input to the transformer. + /// The trained model to be saved. + /// A writeable, seekable stream to save to. + public void Save(DataViewSchema inputSchema, ITransformer model, Stream stream) + { + using (var rep = RepositoryWriter.CreateNew(stream)) + { + ModelSaveContext.SaveModel(rep, model, TransformerDirectory); + SaveInputSchema(inputSchema, rep); + rep.Commit(); + } + } + + private void SaveInputSchema(DataViewSchema inputSchema, RepositoryWriter rep) + { + using (var ch = _env.Start("Saving Schema")) + { + var entry = rep.CreateEntry(SchemaEntryName); + var saver = new BinarySaver(_env, new BinarySaver.Arguments { Silent = true }); + DataSaverUtils.SaveDataView(ch, saver, new EmptyDataView(_env, inputSchema), entry.Stream, keepHidden: true); + } + } + + /// + /// Load the model and its input schema from the stream. /// /// A readable, seekable stream to load from. + /// Will contain the input schema for the model. /// The loaded model. - public ITransformer Load(Stream stream) => TransformerChain.LoadFrom(_env, stream); + public ITransformer Load(Stream stream, out DataViewSchema inputSchema) + { + using (var rep = RepositoryReader.Open(stream, _env)) + { + var entry = rep.OpenEntryOrNull(SchemaEntryName); + if (entry != null) + { + var loader = new BinaryLoader(_env, new BinaryLoader.Arguments(), entry.Stream); + inputSchema = loader.Schema; + } + else + { + // Try to load from legacy model format. + try + { + var loader = ModelFileUtils.LoadLoader(_env, rep, new MultiFileSource(null), false); + inputSchema = loader.Schema; + } + catch (Exception ex) + { + if (!ex.IsMarked()) + throw; + inputSchema = null; + } + } + return TransformerChain.LoadFrom(_env, stream); + } + } - public CompositeDataLoader LoadAsCompositeDataLoader(Stream stream) + /// + /// Load the model and its input schema from the stream. + /// + /// A readable, seekable stream to load from. + /// A model of type containing the loader + /// and the transformer chain. + public CompositeDataLoader Load(Stream stream) { using (var rep = RepositoryReader.Open(stream)) { diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 85b15336e0..8afe9f106b 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -5645,7 +5645,7 @@ public void LoadEntryPointModel() ITransformer loadedModel; using (var stream = File.OpenRead(modelPath)) { - loadedModel = ml.Model.Load(stream); + loadedModel = ml.Model.Load(stream, out var inputSchema); } } diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index a7daf9b985..166c5e8b0a 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -80,10 +80,12 @@ protected void TestEstimatorCore(IEstimator estimator, // Save and reload. string modelPath = GetOutputPath(FullTestName + "-model.zip"); using (var fs = File.Create(modelPath)) - ML.Model.Save(transformer, fs); + ML.Model.Save(validFitInput.Schema, transformer, fs); ITransformer loadedTransformer; - loadedTransformer = ML.Model.Load(modelPath); + DataViewSchema loadedInputSchema; + using (var fs = File.OpenRead(modelPath)) + loadedTransformer = ML.Model.Load(fs, out loadedInputSchema); DeleteOutputPath(modelPath); // Run on train data. @@ -105,6 +107,8 @@ protected void TestEstimatorCore(IEstimator estimator, // Loaded transformer needs to have the same schema propagation. CheckSameSchemas(schema, loadedTransformer.GetOutputSchema(data.Schema)); + // Loaded schema needs to have the same schema as data. + CheckSameSchemas(data.Schema, loadedInputSchema); var scoredTrain = transformer.Transform(data); var scoredTrain2 = loadedTransformer.Transform(data); diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index 515fa98d30..559e4a6f05 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -83,7 +83,10 @@ public void TestEstimatorSaveLoad() model.SaveTo(env, fs); var model2 = TransformerChain.LoadFrom(env, file.OpenReadStream()); - var newCols = ((ImageLoadingTransformer)model2.First()).Columns; + var transformerChain = model2 as TransformerChain; + Assert.NotNull(transformerChain); + + var newCols = ((ImageLoadingTransformer)transformerChain.First()).Columns; var oldCols = ((ImageLoadingTransformer)model.First()).Columns; Assert.True(newCols .Zip(oldCols, (x, y) => x == y) diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs index b3eb49a45c..8b9f8b7a8b 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs @@ -149,7 +149,7 @@ private void TrainRegression(string trainDataPath, string testDataPath, string m using (var stream = File.Create(modelPath)) { // Saving and loading happens to 'dynamic' models, so the static typing is lost in the process. - mlContext.Model.Save(model.AsDynamic, stream); + mlContext.Model.Save(trainData.AsDynamic.Schema, model.AsDynamic, stream); } // Potentially, the lines below can be in a different process altogether. @@ -157,7 +157,7 @@ private void TrainRegression(string trainDataPath, string testDataPath, string m // When you load the model, it's a 'dynamic' transformer. ITransformer loadedModel; using (var stream = File.OpenRead(modelPath)) - loadedModel = mlContext.Model.Load(stream); + loadedModel = mlContext.Model.Load(stream, out var schema); } [Fact] diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs index 3d0fb02ccb..fbb704cabe 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs @@ -123,7 +123,7 @@ private void TrainRegression(string trainDataPath, string testDataPath, string m using (var stream = File.Create(modelPath)) { // Saving and loading happens to 'dynamic' models. - mlContext.Model.Save(model, stream); + mlContext.Model.Save(trainData.Schema, model, stream); } // Potentially, the lines below can be in a different process altogether. @@ -131,7 +131,7 @@ private void TrainRegression(string trainDataPath, string testDataPath, string m // When you load the model, it's a 'dynamic' transformer. ITransformer loadedModel; using (var stream = File.OpenRead(modelPath)) - loadedModel = mlContext.Model.Load(stream); + loadedModel = mlContext.Model.Load(stream, out var schema); } [Fact] @@ -523,7 +523,7 @@ private static void RunEndToEnd(MLContext mlContext, IDataView trainData, string // Save the model. using (var fs = File.Create(modelPath)) - mlContext.Model.Save(model, fs); + mlContext.Model.Save(cachedTrainData.Schema, model, fs); // Now pretend we are in a different process. var newContext = new MLContext(); @@ -535,7 +535,7 @@ private static void RunEndToEnd(MLContext mlContext, IDataView trainData, string // Now we can load the model. ITransformer loadedModel; using (var fs = File.OpenRead(modelPath)) - loadedModel = newContext.Model.Load(fs); + loadedModel = newContext.Model.Load(fs, out var schema); } public static IDataView PrepareData(MLContext mlContext, IDataView data) diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs index ef901b5d3d..c0ae33d498 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs @@ -5,6 +5,7 @@ using System; using System.IO; using System.Linq; +using Microsoft.Data.DataView; using Microsoft.ML.Calibrators; using Microsoft.ML.Data; using Microsoft.ML.RunTests; @@ -41,14 +42,13 @@ public void LoadModelAndExtractPredictor() // Save and reload. string modelPath = GetOutputPath(FullTestName + "-model.zip"); using (var fs = File.Create(modelPath)) - ml.Model.Save(model, fs); + ml.Model.Save(data.Schema, model, fs); ITransformer loadedModel; using (var fs = File.OpenRead(modelPath)) - loadedModel = ml.Model.Load(fs); + loadedModel = ml.Model.Load(fs, out var loadedSchema); - var gam = (((loadedModel as TransformerChain).LastTransformer - as ISingleFeaturePredictionTransformer).Model + var gam = ((loadedModel as ISingleFeaturePredictionTransformer).Model as CalibratedModelParametersBase).SubModel as BinaryClassificationGamModelParameters; Assert.NotNull(gam); @@ -75,10 +75,11 @@ public void SaveAndLoadModelWithLoader() IDataLoader loadedModel; ITransformer loadedModelWithoutLoader; + DataViewSchema loadedSchema; using (var fs = File.OpenRead(modelPath)) { - loadedModel = ml.Model.LoadAsCompositeDataLoader(fs); - loadedModelWithoutLoader = ml.Model.Load(fs); + loadedModel = ml.Model.Load(fs); + loadedModelWithoutLoader = ml.Model.Load(fs, out loadedSchema); } // Without deserializing the loader from the model we lose the slot names. diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs index f0e6ed6b5e..bba3945931 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs @@ -504,7 +504,11 @@ public void MatrixFactorizationBackCompat() using (var ch = Env.Start("load")) { using (var fs = File.OpenRead(modelPath)) - model = ML.Model.Load(fs); + { + model = ML.Model.Load(fs, out var schema); + // This model was saved without the input schema. + Assert.Null(schema); + } } // Create data for testing. Note that the 2nd element is not specified in the training data so it should diff --git a/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs b/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs index 46ffbdb031..e0167a241e 100644 --- a/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs @@ -261,7 +261,7 @@ public void TypeConvertKeyBackCompatTest() using (var ch = Env.Start("load")) { using (var fs = File.OpenRead(modelPath)) - modelOld = ML.Model.Load(fs); + modelOld = ML.Model.Load(fs, out var schema); } var outDataOld = modelOld.Transform(dataView); From bc9535fd89e4a376d816bae5537c69604adc3f81 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Mon, 11 Mar 2019 11:39:44 -0700 Subject: [PATCH 05/18] Fix build after rebase --- .../Scenarios/Api/Estimators/DeserializationTests.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs index c0ae33d498..8743fa3c1d 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs @@ -28,7 +28,7 @@ private class InputData [Fact] public void LoadModelAndExtractPredictor() { - var ml = new MLContext(seed: 1, conc: 1); + var ml = new MLContext(seed: 1); var file = new MultiFileSource(GetDataPath(TestDatasets.adult.trainFilename)); var loader = ml.Data.CreateTextLoader(hasHeader: true, dataSample: file); var data = loader.Load(file); @@ -57,7 +57,7 @@ public void LoadModelAndExtractPredictor() [Fact] public void SaveAndLoadModelWithLoader() { - var ml = new MLContext(seed: 1, conc: 1); + var ml = new MLContext(seed: 1); var file = new MultiFileSource(GetDataPath(TestDatasets.adult.trainFilename)); var loader = ml.Data.CreateTextLoader(hasHeader: true, dataSample: file); var data = loader.Load(file); From 0a130093a9015d67e54a9d18575e12137db61040 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Mon, 11 Mar 2019 15:06:01 -0700 Subject: [PATCH 06/18] Add API to create PredictionEngine with input schema --- .../Dynamic/TensorFlow/TextClassification.cs | 6 ++--- .../DataLoadSave/TransformerChain.cs | 2 +- .../DataView/DataViewConstructionUtils.cs | 23 ++++++++++++------- .../Model/PredictionEngineExtensions.cs | 9 +++++--- .../Prediction/PredictionEngine.cs | 6 ++--- .../PredictionFunction.cs | 4 ++-- .../PredictionEngineBench.cs | 6 ++--- ...sticDualCoordinateAscentClassifierBench.cs | 2 +- .../Api/CookbookSamples/CookbookSamples.cs | 2 +- .../CookbookSamplesDynamicApi.cs | 2 +- .../Estimators/DecomposableTrainAndPredict.cs | 2 +- .../Scenarios/Api/Estimators/Extensibility.cs | 2 +- .../Api/Estimators/MultithreadedPrediction.cs | 2 +- .../Api/Estimators/PredictAndMetadata.cs | 2 +- .../Api/Estimators/SimpleTrainAndPredict.cs | 4 ++-- .../Estimators/TrainSaveModelAndPredict.cs | 9 ++++---- .../Scenarios/ClusteringTests.cs | 2 +- .../Scenarios/IrisPlantClassificationTests.cs | 2 +- ...PlantClassificationWithStringLabelTests.cs | 2 +- .../Scenarios/TensorflowTests.cs | 2 +- .../IrisPlantClassificationTests.cs | 2 +- .../TensorflowTests.cs | 22 +++++++++--------- 22 files changed, 63 insertions(+), 52 deletions(-) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/TextClassification.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/TextClassification.cs index b2b5363a8d..7a28c2a463 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/TextClassification.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/TextClassification.cs @@ -68,13 +68,13 @@ public static void Example() j.Features = features; }; - var engine = mlContext.Transforms.Text.TokenizeIntoWords("TokenizedWords", "Sentiment_Text") + var model = mlContext.Transforms.Text.TokenizeWords("TokenizedWords", "Sentiment_Text") .Append(mlContext.Transforms.Conversion.MapValue(lookupMap, "Words", "Ids", new ColumnOptions[] { ("VariableLenghtFeatures", "TokenizedWords") })) .Append(mlContext.Transforms.CustomMapping(ResizeFeaturesAction, "Resize")) .Append(tensorFlowModel.ScoreTensorFlowModel(new[] { "Prediction/Softmax" }, new[] { "Features" })) .Append(mlContext.Transforms.CopyColumns(("Prediction", "Prediction/Softmax"))) - .Fit(dataView) - .CreatePredictionEngine(mlContext); + .Fit(dataView); + var engine = mlContext.Model.CreatePredictionEngine(model); // Predict with TensorFlow pipeline. var prediction = engine.Predict(data[0]); diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs index a5f7bc621e..6339f72546 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs @@ -237,7 +237,7 @@ IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) /// /// Saving/loading routines for transformer chains. /// - public static class TransformerChain + internal static class TransformerChain { public const string LoaderSignature = "TransformerChain"; diff --git a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs index 8ed0dbfa61..6fc6490aa1 100644 --- a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs @@ -56,30 +56,37 @@ public static StreamingDataView CreateFromEnumerable(IHostEnvironmen return new StreamingDataView(env, data, GetInternalSchemaDefinition(env, schema)); } - private static InternalSchemaDefinition GetInternalSchemaDefinition(IHostEnvironment env, DataViewSchema schema) - where TRow : class + internal static SchemaDefinition GetSchemaDefinition(IHostEnvironment env, DataViewSchema schema) { Contracts.AssertValue(env); env.AssertValue(schema); - var isd = InternalSchemaDefinition.Create(typeof(TRow), SchemaDefinition.Direction.Read); + var schemaDefinition = SchemaDefinition.Create(typeof(TRow), SchemaDefinition.Direction.Read); foreach (var col in schema) { var name = col.Name; - var isdCol = isd.Columns.FirstOrDefault(c => c.ColumnName == name); - if (isdCol == null) - throw env.Except($"Type should contain a member named {isdCol.ColumnName}"); + var schemaDefinitionCol = schemaDefinition.FirstOrDefault(c => c.ColumnName == name); + if (schemaDefinitionCol == null) + throw env.Except($"Type should contain a member named {name}"); var annotations = col.Annotations; if (annotations != null) { foreach (var annotation in annotations.Schema) { var info = Utils.MarshalInvoke(GetAnnotationInfo, annotation.Type.RawType, annotation.Name, annotations); - isdCol.Annotations.Add(annotation.Name, info); + schemaDefinitionCol.Annotations.Add(annotation.Name, info); } } } - return isd; + return schemaDefinition; + } + + private static InternalSchemaDefinition GetInternalSchemaDefinition(IHostEnvironment env, DataViewSchema schema) + where TRow : class + { + Contracts.AssertValue(env); + env.AssertValue(schema); + return InternalSchemaDefinition.Create(typeof(TRow), GetSchemaDefinition(env, schema)); } private static AnnotationInfo GetAnnotationInfo(string kind, DataViewSchema.Annotations annotations) diff --git a/src/Microsoft.ML.Data/Model/PredictionEngineExtensions.cs b/src/Microsoft.ML.Data/Model/PredictionEngineExtensions.cs index 0231eebb10..35d7fbf148 100644 --- a/src/Microsoft.ML.Data/Model/PredictionEngineExtensions.cs +++ b/src/Microsoft.ML.Data/Model/PredictionEngineExtensions.cs @@ -10,7 +10,7 @@ namespace Microsoft.ML /// /// Extension methods to create a prediction engine. /// - public static class PredictionEngineExtensions + internal static class PredictionEngineExtensions { /// /// Create a prediction engine for one-time prediction. @@ -19,12 +19,15 @@ public static class PredictionEngineExtensions /// The class that defines the output data. /// The transformer to use for prediction. /// The environment to use. + /// Whether to throw an exception if a column exists in + /// but the corresponding member doesn't exist in + /// . /// Additional settings of the input schema. /// Additional settings of the output schema. public static PredictionEngine CreatePredictionEngine(this ITransformer transformer, - IHostEnvironment env, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) + IHostEnvironment env, bool ignoreMissingColumns = true, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) where TSrc : class where TDst : class, new() - => new PredictionEngine(env, transformer, true, inputSchemaDefinition, outputSchemaDefinition); + => new PredictionEngine(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition); } } diff --git a/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs b/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs index 71daa11841..e01f08f8f1 100644 --- a/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs +++ b/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs @@ -131,13 +131,13 @@ private protected PredictionEngineBase(IHostEnvironment env, ITransformer transf var makeMapper = TransformerChecker(env, transformer); env.AssertValue(makeMapper); _inputRow = DataViewConstructionUtils.CreateInputRow(env, inputSchemaDefinition); - PredictionEngineCore(env, _inputRow, makeMapper(_inputRow.Schema), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition, out _disposer, out _outputRow); + PredictionEngineCore(env, _inputRow, makeMapper(_inputRow.Schema), ignoreMissingColumns, outputSchemaDefinition, out _disposer, out _outputRow); OutputSchema = Transformer.GetOutputSchema(_inputRow.Schema); } [BestFriend] - private protected virtual void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow inputRow, IRowToRowMapper mapper, bool ignoreMissingColumns, - SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition, out Action disposer, out IRowReadableAs outputRow) + private protected virtual void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow inputRow, + IRowToRowMapper mapper, bool ignoreMissingColumns, SchemaDefinition outputSchemaDefinition, out Action disposer, out IRowReadableAs outputRow) { var cursorable = TypedCursorable.Create(env, new EmptyDataView(env, mapper.OutputSchema), ignoreMissingColumns, outputSchemaDefinition); var outputRowLocal = mapper.GetRow(inputRow, mapper.OutputSchema); diff --git a/src/Microsoft.ML.TimeSeries/PredictionFunction.cs b/src/Microsoft.ML.TimeSeries/PredictionFunction.cs index aad6a2ff76..070e0c3d10 100644 --- a/src/Microsoft.ML.TimeSeries/PredictionFunction.cs +++ b/src/Microsoft.ML.TimeSeries/PredictionFunction.cs @@ -165,8 +165,8 @@ private Action CreatePinger(List rows) return pinger; } - private protected override void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow inputRow, IRowToRowMapper mapper, bool ignoreMissingColumns, - SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition, out Action disposer, out IRowReadableAs outputRow) + private protected override void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow inputRow, + IRowToRowMapper mapper, bool ignoreMissingColumns, SchemaDefinition outputSchemaDefinition, out Action disposer, out IRowReadableAs outputRow) { List rows = new List(); DataViewRow outputRowLocal = outputRowLocal = GetStatefulRows(inputRow, mapper, mapper.OutputSchema, rows); diff --git a/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs b/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs index a19eb80aec..29aa387ce0 100644 --- a/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs +++ b/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs @@ -62,7 +62,7 @@ public void SetupIrisPipeline() var model = pipeline.Fit(data); - _irisModel = model.CreatePredictionEngine(env); + _irisModel = env.Model.CreatePredictionEngine(model); } [GlobalSetup(Target = nameof(MakeSentimentPredictions))] @@ -97,7 +97,7 @@ public void SetupSentimentPipeline() var model = pipeline.Fit(data); - _sentimentModel = model.CreatePredictionEngine(mlContext); + _sentimentModel = mlContext.Model.CreatePredictionEngine(model); } [GlobalSetup(Target = nameof(MakeBreastCancerPredictions))] @@ -131,7 +131,7 @@ public void SetupBreastCancerPipeline() var model = pipeline.Fit(data); - _breastCancerModel = model.CreatePredictionEngine(env); + _breastCancerModel = env.Model.CreatePredictionEngine(model); } [Benchmark] diff --git a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs index ed7441bb4a..d86764e397 100644 --- a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs +++ b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs @@ -123,7 +123,7 @@ public void TrainSentiment() public void SetupPredictBenchmarks() { _trainedModel = Train(_dataPath); - _predictionEngine = _trainedModel.CreatePredictionEngine(mlContext); + _predictionEngine = mlContext.Model.CreatePredictionEngine(_trainedModel); _consumer.Consume(_predictionEngine.Predict(_example)); // Create text loader. diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs index 8b9f8b7a8b..71ab03fe9b 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs @@ -226,7 +226,7 @@ private void PredictOnIris(ITransformer model) // Make the prediction function object. Note that, on average, this call takes around 200x longer // than one prediction, so you might want to cache and reuse the prediction function, instead of // creating one per prediction. - var predictionFunc = model.CreatePredictionEngine(mlContext); + var predictionFunc = mlContext.Model.CreatePredictionEngine(model); // Obtain the prediction. Remember that 'Predict' is not reentrant. If you want to use multiple threads // for simultaneous prediction, make sure each thread is using its own PredictionFunction. diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs index fbb704cabe..bf077d8cfc 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs @@ -208,7 +208,7 @@ private void PredictOnIris(ITransformer model) // Make the prediction function object. Note that, on average, this call takes around 200x longer // than one prediction, so you might want to cache and reuse the prediction function, instead of // creating one per prediction. - var predictionFunc = model.CreatePredictionEngine(mlContext); + var predictionFunc = mlContext.CreatePredictionEngine(model); // Obtain the prediction. Remember that 'Predict' is not reentrant. If you want to use multiple threads // for simultaneous prediction, make sure each thread is using its own PredictionFunction. diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs index 441732aa39..e8c05bb5bf 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs @@ -36,7 +36,7 @@ void DecomposableTrainAndPredict() .Append(new KeyToValueMappingEstimator(ml, "PredictedLabel")); var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring); - var engine = model.CreatePredictionEngine(ml); + var engine = ml.Model.CreatePredictionEngine(model); var testLoader = ml.Data.LoadFromTextFile(dataPath, TestDatasets.irisData.GetLoaderColumns(), separatorChar: ',', hasHeader: true); var testData = ml.Data.CreateEnumerable(testLoader, false); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs index e04a0ca6f1..e977884d7e 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs @@ -45,7 +45,7 @@ void Extensibility() .Append(new KeyToValueMappingEstimator(ml, "PredictedLabel")); var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring); - var engine = model.CreatePredictionEngine(ml); + var engine = ml.Model.CreatePredictionEngine(model); var testLoader = ml.Data.LoadFromTextFile(dataPath, TestDatasets.irisData.GetLoaderColumns(), separatorChar: ','); var testData = ml.Data.CreateEnumerable(testLoader, false); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs index bac7c98229..daacc9fcae 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs @@ -36,7 +36,7 @@ void MultithreadedPrediction() var model = pipeline.Fit(data); // Create prediction engine and test predictions. - var engine = model.CreatePredictionEngine(ml); + var engine = ml.Model.CreatePredictionEngine(model); // Take a couple examples out of the test data and run predictions on top. var testData = ml.Data.CreateEnumerable( diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs index c293ec6970..75dadbb957 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/PredictAndMetadata.cs @@ -33,7 +33,7 @@ void PredictAndMetadata() new SdcaMulticlassClassificationTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1, })); var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring); - var engine = model.CreatePredictionEngine(ml); + var engine = ml.Model.CreatePredictionEngine(model); var testLoader = ml.Data.LoadFromTextFile(dataPath, TestDatasets.irisData.GetLoaderColumns(), separatorChar: ',', hasHeader: true); var testData = ml.Data.CreateEnumerable(testLoader, false); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs index 0b7c8fd430..2b16c7ed3b 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs @@ -33,7 +33,7 @@ public void SimpleTrainAndPredict() var model = pipeline.Fit(data); // Create prediction engine and test predictions. - var engine = model.CreatePredictionEngine(ml); + var engine = ml.Model.CreatePredictionEngine(model); // Take a couple examples out of the test data and run predictions on top. var testData = ml.Data.CreateEnumerable( @@ -72,7 +72,7 @@ public void SimpleTrainAndPredictSymSGD() var model = pipeline.Fit(data); // Create prediction engine and test predictions. - var engine = model.CreatePredictionEngine(ml); + var engine = ml.Model.CreatePredictionEngine(model); // Take a couple examples out of the test data and run predictions on top. var testData = ml.Data.CreateEnumerable( diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs index f5603f633b..9e6cfe9fce 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs @@ -4,7 +4,7 @@ using System.IO; using System.Linq; -using Microsoft.ML.Data; +using Microsoft.Data.DataView; using Microsoft.ML.RunTests; using Microsoft.ML.Trainers; using Xunit; @@ -37,15 +37,16 @@ public void TrainSaveModelAndPredict() var modelPath = GetOutputPath("temp.zip"); // Save model. using (var file = File.Create(modelPath)) - model.SaveTo(ml, file); + ml.Model.Save(data.Schema, model, file); // Load model. ITransformer loadedModel; + DataViewSchema inputSchema; using (var file = File.OpenRead(modelPath)) - loadedModel = TransformerChain.LoadFrom(ml, file); + loadedModel = ml.Model.Load(file, out inputSchema); // Create prediction engine and test predictions. - var engine = loadedModel.CreatePredictionEngine(ml); + var engine = ml.Model.CreatePredictionEngine(loadedModel, inputSchema); // Take a couple examples out of the test data and run predictions on top. var testData = ml.Data.CreateEnumerable( diff --git a/test/Microsoft.ML.Tests/Scenarios/ClusteringTests.cs b/test/Microsoft.ML.Tests/Scenarios/ClusteringTests.cs index ef95c95b45..2bf15ea1a8 100644 --- a/test/Microsoft.ML.Tests/Scenarios/ClusteringTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/ClusteringTests.cs @@ -70,7 +70,7 @@ public void PredictClusters() // Validate that initial points we pick up as centers of cluster during data generation belong to different clusters. var labels = new HashSet(); - var predictFunction = trainedModel.CreatePredictionEngine(mlContext); + var predictFunction = mlContext.Model.CreatePredictionEngine(trainedModel); for (int i = 0; i < k; i++) { diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs index e813d5c9c7..af5002b050 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs @@ -45,7 +45,7 @@ public void TrainAndPredictIrisModelTest() var trainedModel = pipe.Fit(trainData); // Make predictions - var predictFunction = trainedModel.CreatePredictionEngine(mlContext); + var predictFunction = mlContext.Model.CreatePredictionEngine(trainedModel); IrisPrediction prediction = predictFunction.Predict(new IrisData() { SepalLength = 5.1f, diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs index d86b605d1a..b9278c02c6 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs @@ -45,7 +45,7 @@ public void TrainAndPredictIrisModelWithStringLabelTest() var trainedModel = pipe.Fit(trainData); // Make predictions - var predictFunction = trainedModel.CreatePredictionEngine(mlContext); + var predictFunction = mlContext.CreatePredictionEngine(trainedModel); IrisPredictionWithStringLabel prediction = predictFunction.Predict(new IrisDataWithStringLabel() { SepalLength = 5.1f, diff --git a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs index 5205e99666..db6f57ec99 100644 --- a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs @@ -48,7 +48,7 @@ public void TensorFlowTransforCifarEndToEndTest() var metrics = mlContext.MulticlassClassification.Evaluate(predictions); Assert.Equal(1, metrics.MicroAccuracy, 2); - var predictFunction = transformer.CreatePredictionEngine(mlContext); + var predictFunction = mlContext.Model.CreatePredictionEngine(transformer); var prediction = predictFunction.Predict(new CifarData() { ImagePath = GetDataPath("images/banana.jpg") diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs index 871eeb209f..1a03391b82 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/IrisPlantClassificationTests.cs @@ -46,7 +46,7 @@ public void TrainAndPredictIrisModelUsingDirectInstantiationTest() var predicted = trainedModel.Transform(testData); var metrics = mlContext.MulticlassClassification.Evaluate(predicted); CompareMetrics(metrics); - var predictFunction = trainedModel.CreatePredictionEngine(mlContext); + var predictFunction = mlContext.Model.CreatePredictionEngine(trainedModel); ComparePredictions(predictFunction); } diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index fee24d2a14..ac68d6a1bd 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -521,7 +521,7 @@ public void TensorFlowTransformMNISTConvTest() var oneSample = GetOneMNISTExample(); - var predictFunction = trainedModel.CreatePredictionEngine(mlContext); + var predictFunction = mlContext.Model.CreatePredictionEngine(trainedModel); var onePrediction = predictFunction.Predict(oneSample); @@ -570,7 +570,7 @@ public void TensorFlowTransformMNISTLRTrainingTest() var metrics = mlContext.MulticlassClassification.Evaluate(predicted, label: "KeyLabel"); Assert.InRange(metrics.MicroAccuracy, expectedMicroAccuracy, 1); Assert.InRange(metrics.MacroAccuracy, expectedMacroAccruacy, 1); - var predictionFunction = trainedModel.CreatePredictionEngine(mlContext); + var predictionFunction = mlContext.Model.CreatePredictionEngine(trainedModel); var oneSample = GetOneMNISTExample(); var onePrediction = predictionFunction.Predict(oneSample); @@ -693,7 +693,7 @@ private void ExecuteTFTransformMNISTConvTrainingTest(bool shuffle, int? shuffleS Assert.InRange(metrics.MacroAccuracy, expectedMacroAccuracy - 0.1, expectedMacroAccuracy + 0.1); // Create prediction function and test prediction - var predictFunction = trainedModel.CreatePredictionEngine(mlContext); + var predictFunction = mlContext.Model.CreatePredictionEngine(trainedModel); var oneSample = GetOneMNISTExample(); @@ -745,7 +745,7 @@ public void TensorFlowTransformMNISTConvSavedModelTest() // An in-memory example. Its label is predicted below. var oneSample = GetOneMNISTExample(); - var predictFunction = trainedModel.CreatePredictionEngine(mlContext); + var predictFunction = mlContext.Model.CreatePredictionEngine(trainedModel); var onePrediction = predictFunction.Predict(oneSample); @@ -998,18 +998,18 @@ public void TensorFlowSentimentClassificationTest() // The first pipeline 'dataPipe' tokenzies the string into words and maps each word to an integer which is an index in the dictionary. // Then this integer vector is retrieved from the pipeline and resized to fixed length. // The second pipeline 'tfEnginePipe' takes the resized integer vector and passes it to TensoFlow and gets the classification scores. - var estimator = mlContext.Transforms.Text.TokenizeIntoWords("TokenizedWords", "Sentiment_Text") + var estimator = mlContext.Transforms.Text.TokenizeWords("TokenizedWords", "Sentiment_Text") .Append(mlContext.Transforms.Conversion.MapValue(lookupMap, "Words", "Ids", new ColumnOptions[] { ("Features", "TokenizedWords") })); - var dataPipe = estimator.Fit(dataView) - .CreatePredictionEngine(mlContext); + var model = estimator.Fit(dataView); + var dataPipe = mlContext.CreatePredictionEngine(model); // For explanation on how was the `sentiment_model` created // c.f. https://github.com/dotnet/machinelearning-testdata/blob/master/Microsoft.ML.TensorFlow.TestModels/sentiment_model/README.md string modelLocation = @"sentiment_model"; - var tfEnginePipe = mlContext.Model.LoadTensorFlowModel(modelLocation).ScoreTensorFlowModel(new[] { "Prediction/Softmax" }, new[] { "Features" }) + var pipelineModel = mlContext.Model.LoadTensorFlowModel(modelLocation).ScoreTensorFlowModel(new[] { "Prediction/Softmax" }, new[] { "Features" }) .Append(mlContext.Transforms.CopyColumns(("Prediction", "Prediction/Softmax"))) - .Fit(dataView) - .CreatePredictionEngine(mlContext); + .Fit(dataView); + var tfEnginePipe = mlContext.CreatePredictionEngine(pipelineModel); var processedData = dataPipe.Predict(data[0]); Array.Resize(ref processedData.Features, 600); @@ -1052,7 +1052,7 @@ public void TensorFlowStringTest() var pipeline = tensorFlowModel.ScoreTensorFlowModel(new[] { "Original_A", "Joined_Splited_Text" }, new[] { "A", "B" }) .Append(mlContext.Transforms.CopyColumns(("AOut", "Original_A"), ("BOut", "Joined_Splited_Text"))); - var transformer = pipeline.Fit(dataview).CreatePredictionEngine(mlContext); + var transformer = mlContext.Model.CreatePredictionEngine(pipeline.Fit(dataview)); var input = new TextInput { From 573e939684d38d02d2cca6f467860774c2d89391 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Mon, 11 Mar 2019 15:39:48 -0700 Subject: [PATCH 07/18] Address code review comments --- .../Properties/AssemblyInfo.cs | 1 + .../Microsoft.ML.Functional.Tests.csproj | 2 +- .../ModelLoading.cs} | 29 ++++++++++++++++--- 3 files changed, 27 insertions(+), 5 deletions(-) rename test/{Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs => Microsoft.ML.Functional.Tests/ModelLoading.cs} (79%) diff --git a/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs index 1d1254a089..851b461424 100644 --- a/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs @@ -13,6 +13,7 @@ [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.InferenceTesting" + PublicKey.TestValue)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StaticPipelineTesting" + PublicKey.TestValue)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxTransformerTest" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Functional.Tests" + PublicKey.TestValue)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.EntryPoints" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Maml" + PublicKey.Value)] diff --git a/test/Microsoft.ML.Functional.Tests/Microsoft.ML.Functional.Tests.csproj b/test/Microsoft.ML.Functional.Tests/Microsoft.ML.Functional.Tests.csproj index fe8f59ce0f..4b7f6b177f 100644 --- a/test/Microsoft.ML.Functional.Tests/Microsoft.ML.Functional.Tests.csproj +++ b/test/Microsoft.ML.Functional.Tests/Microsoft.ML.Functional.Tests.csproj @@ -2,7 +2,7 @@ - false + true false diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs similarity index 79% rename from test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs rename to test/Microsoft.ML.Functional.Tests/ModelLoading.cs index 8743fa3c1d..935418262d 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs +++ b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs @@ -9,17 +9,23 @@ using Microsoft.ML.Calibrators; using Microsoft.ML.Data; using Microsoft.ML.RunTests; +using Microsoft.ML.TestFramework; using Microsoft.ML.Trainers.FastTree; using Xunit; +using Xunit.Abstractions; -namespace Microsoft.ML.Tests.Scenarios.Api +namespace Microsoft.ML.Functional.Tests { - public partial class ApiScenariosTests + public partial class ModelLoadingTests : BaseTestClass { + public ModelLoadingTests(ITestOutputHelper output) : base(output) + { + } + private class InputData { [LoadColumn(0)] - public float Label { get; set; } + public bool Label { get; set; } [LoadColumn(9, 14)] [VectorType(6)] public float[] Features { get; set; } @@ -35,23 +41,38 @@ public void LoadModelAndExtractPredictor() // Pipeline. var pipeline = ml.BinaryClassification.Trainers.GeneralizedAdditiveModels(); - + // Define the same pipeline starting with the loader. + var pipeline1 = loader.Append(ml.BinaryClassification.Trainers.GeneralizedAdditiveModels()); + // Train. var model = pipeline.Fit(data); + var model1 = pipeline1.Fit(file); // Save and reload. string modelPath = GetOutputPath(FullTestName + "-model.zip"); using (var fs = File.Create(modelPath)) ml.Model.Save(data.Schema, model, fs); + string modelPath1 = GetOutputPath(FullTestName + "-model1.zip"); + using (var fs = File.Create(modelPath1)) + ml.Model.Save(model1, fs); ITransformer loadedModel; + IDataLoader loadedModel1; using (var fs = File.OpenRead(modelPath)) loadedModel = ml.Model.Load(fs, out var loadedSchema); + using (var fs = File.OpenRead(modelPath1)) + loadedModel1 = ml.Model.Load(fs); var gam = ((loadedModel as ISingleFeaturePredictionTransformer).Model as CalibratedModelParametersBase).SubModel as BinaryClassificationGamModelParameters; Assert.NotNull(gam); + + gam = (((loadedModel1 as CompositeDataLoader).Transformer.LastTransformer + as ISingleFeaturePredictionTransformer).Model + as CalibratedModelParametersBase).SubModel + as BinaryClassificationGamModelParameters; + Assert.NotNull(gam); } [Fact] From f3fdf89398856fabe3140eca37ee63182d87c60c Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Tue, 12 Mar 2019 09:03:59 -0700 Subject: [PATCH 08/18] Unfriend Functional.Tests --- src/Microsoft.ML.Core/Properties/AssemblyInfo.cs | 1 - test/Microsoft.ML.Functional.Tests/ModelLoading.cs | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs index 851b461424..1d1254a089 100644 --- a/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs @@ -13,7 +13,6 @@ [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.InferenceTesting" + PublicKey.TestValue)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StaticPipelineTesting" + PublicKey.TestValue)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxTransformerTest" + PublicKey.TestValue)] -[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Functional.Tests" + PublicKey.TestValue)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.EntryPoints" + PublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Maml" + PublicKey.Value)] diff --git a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs index 935418262d..6acc112cc3 100644 --- a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs +++ b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs @@ -106,10 +106,10 @@ public void SaveAndLoadModelWithLoader() // Without deserializing the loader from the model we lose the slot names. data = ml.Data.LoadFromEnumerable(new[] { new InputData() }); data = loadedModelWithoutLoader.Transform(data); - Assert.Null(data.Schema["Features"].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)); + Assert.True(!data.Schema["Features"].HasSlotNames()); data = loadedModel.Load(file); - Assert.True(data.Schema["Features"].HasSlotNames(data.Schema["Features"].Type.GetValueCount())); + Assert.True(data.Schema["Features"].HasSlotNames()); VBuffer> slotNames = default; data.Schema["Features"].GetSlotNames(ref slotNames); var ageIndex = FindIndex(slotNames.GetValues(), "age"); From 17462daa13fb104999eb3ec041bc84d656e57413 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Tue, 12 Mar 2019 11:54:50 -0700 Subject: [PATCH 09/18] Add CreatePredictionEngine API back to ModelOperationsCatalog --- .../Model/ModelOperationsCatalog.cs | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs index 0a87cd26ef..dceec103c5 100644 --- a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs @@ -159,5 +159,32 @@ internal ExplainabilityTransforms(ModelOperationsCatalog owner) _env = owner._env; } } + + /// + /// Create a prediction engine for one-time prediction. + /// + /// The class that defines the input data. + /// The class that defines the output data. + /// The transformer to use for prediction. + /// Whether to throw an exception if a column exists in + /// but the corresponding member doesn't exist in + /// . + /// Additional settings of the input schema. + /// Additional settings of the output schema. + public PredictionEngine CreatePredictionEngine(ITransformer transformer, + bool ignoreMissingColumns = true, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) + where TSrc : class + where TDst : class, new() + { + return transformer.CreatePredictionEngine(_env, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition); + } + + public PredictionEngine CreatePredictionEngine(ITransformer transformer, DataViewSchema inputSchema) + where TSrc : class + where TDst : class, new() + { + return transformer.CreatePredictionEngine(_env, false, + DataViewConstructionUtils.GetSchemaDefinition(_env, inputSchema)); + } } } From 9177e43c0d9c5050d222be8b9afa872b2956eace Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Wed, 13 Mar 2019 08:55:25 -0700 Subject: [PATCH 10/18] Address code review comments --- .../DataLoadSave/CompositeDataLoader.cs | 17 +++--- .../DataLoadSave/TransformerChain.cs | 4 +- .../Model/ModelOperationsCatalog.cs | 29 +++++++--- .../UnitTests/TestEntryPoints.cs | 2 +- .../ModelLoading.cs | 56 +++++++++++++------ .../Api/CookbookSamples/CookbookSamples.cs | 3 +- .../CookbookSamplesDynamicApi.cs | 4 +- .../MatrixFactorizationTests.cs | 3 +- .../Transformers/ConvertTests.cs | 2 +- 9 files changed, 81 insertions(+), 39 deletions(-) diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs index f7869de3d6..fcc808406c 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System.IO; using Microsoft.ML; using Microsoft.ML.Data; using Microsoft.ML.Runtime; @@ -19,6 +18,10 @@ namespace Microsoft.ML.Data public sealed class CompositeDataLoader : IDataLoader where TLastTransformer : class, ITransformer { + private const string LoaderDirectory = "Loader"; + private const string LegacyLoaderDirectory = "Reader"; + private const string TransformerDirectory = TransformerChain.LoaderSignature; + /// /// The underlying data loader. /// @@ -39,9 +42,9 @@ public CompositeDataLoader(IDataLoader loader, TransformerChain, SignatureLoadModel>(host, out Loader, ModelOperationsCatalog.LegacyLoaderDirectory)) - ctx.LoadModel, SignatureLoadModel>(host, out Loader, ModelOperationsCatalog.LoaderDirectory); - ctx.LoadModel, SignatureLoadModel>(host, out Transformer, ModelOperationsCatalog.TransformerDirectory); + if (!ctx.LoadModelOrNull, SignatureLoadModel>(host, out Loader, LegacyLoaderDirectory)) + ctx.LoadModel, SignatureLoadModel>(host, out Loader, LoaderDirectory); + ctx.LoadModel, SignatureLoadModel>(host, out Transformer, TransformerDirectory); } private static CompositeDataLoader Create(IHostEnvironment env, ModelLoadContext ctx) @@ -90,11 +93,11 @@ void ICanSaveModel.Save(ModelSaveContext ctx) ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); - ctx.SaveModel(Loader, ModelOperationsCatalog.LoaderDirectory); - ctx.SaveModel(Transformer, ModelOperationsCatalog.TransformerDirectory); + ctx.SaveModel(Loader, LoaderDirectory); + ctx.SaveModel(Transformer, TransformerDirectory); } - internal const string Summary = "A loader that encapsulates a loader and a transformer chain."; + internal const string Summary = "A model loader that encapsulates a data loader and a transformer chain."; internal const string LoaderSignature = "CompositeLoader"; private static VersionInfo GetVersionInfo() diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs index 6339f72546..6eabf64449 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs @@ -256,9 +256,7 @@ public static ITransformer LoadFrom(IHostEnvironment env, Stream stream) { try { - ModelLoadContext.LoadModelOrNull(env, out var transformerChain, rep, LoaderSignature); - if (transformerChain == null) - ModelLoadContext.LoadModel(env, out transformerChain, rep, $@"Model\{LoaderSignature}"); + ModelLoadContext.LoadModel(env, out var transformerChain, rep, LoaderSignature); return transformerChain; } catch (FormatException ex) diff --git a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs index dceec103c5..374019ac8a 100644 --- a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs @@ -17,10 +17,7 @@ namespace Microsoft.ML /// public sealed class ModelOperationsCatalog : IInternalCatalog { - internal const string LoaderDirectory = "Loader"; - internal const string LegacyLoaderDirectory = "Reader"; - internal const string TransformerDirectory = TransformerChain.LoaderSignature; - internal const string SchemaEntryName = "Schema"; + private const string SchemaEntryName = "Schema"; IHostEnvironment IInternalCatalog.Environment => _env; private readonly IHostEnvironment _env; @@ -44,7 +41,7 @@ public void Save(IDataLoader model, Stream stream) { using (var rep = RepositoryWriter.CreateNew(stream)) { - ModelSaveContext.SaveModel(rep, model, "Model"); + ModelSaveContext.SaveModel(rep, model, null); SaveInputSchema(model.GetOutputSchema(), rep); rep.Commit(); } @@ -89,7 +86,8 @@ private void SaveInputSchema(DataViewSchema inputSchema, RepositoryWriter rep) /// Load the model and its input schema from the stream. /// /// A readable, seekable stream to load from. - /// Will contain the input schema for the model. + /// Will contain the input schema for the model. If the model was saved using older APIs + /// it may not contain an input schema, in this case will be null. /// The loaded model. public ITransformer Load(Stream stream, out DataViewSchema inputSchema) { @@ -130,11 +128,28 @@ public CompositeDataLoader Load(Stream stream) { using (var rep = RepositoryReader.Open(stream)) { - ModelLoadContext.LoadModel, SignatureLoadModel>(_env, out var model, rep, "Model"); + ModelLoadContext.LoadModel, SignatureLoadModel>(_env, out var model, rep, null); return model; } } + /// + /// Load a transformer model and a data loader model from the stream. + /// + /// A readable, seekable stream to load from. + /// The data loader from the model stream. + /// The transformer model from the model stream. + public ITransformer Load(Stream stream, out IDataLoader loader) + { + loader = Load(stream); + if (loader is CompositeDataLoader composite) + { + loader = composite.Loader; + return composite.Transformer; + } + return new TransformerChain(); + } + /// /// Load the model from a file path. /// diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 8afe9f106b..f9753997d5 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -5645,7 +5645,7 @@ public void LoadEntryPointModel() ITransformer loadedModel; using (var stream = File.OpenRead(modelPath)) { - loadedModel = ml.Model.Load(stream, out var inputSchema); + loadedModel = ml.Model.Load(stream, out DataViewSchema inputSchema); } } diff --git a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs index 6acc112cc3..dc598bc44d 100644 --- a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs +++ b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs @@ -45,30 +45,54 @@ public void LoadModelAndExtractPredictor() var pipeline1 = loader.Append(ml.BinaryClassification.Trainers.GeneralizedAdditiveModels()); // Train. - var model = pipeline.Fit(data); - var model1 = pipeline1.Fit(file); + var transformerModel = pipeline.Fit(data); + var compositeLoaderModel = pipeline1.Fit(file); // Save and reload. - string modelPath = GetOutputPath(FullTestName + "-model.zip"); - using (var fs = File.Create(modelPath)) - ml.Model.Save(data.Schema, model, fs); - string modelPath1 = GetOutputPath(FullTestName + "-model1.zip"); - using (var fs = File.Create(modelPath1)) - ml.Model.Save(model1, fs); + string modelAndSchemaPath = GetOutputPath(FullTestName + "-model-schema.zip"); + using (var fs = File.Create(modelAndSchemaPath)) + ml.Model.Save(data.Schema, transformerModel, fs); + string compositeLoaderModelPath = GetOutputPath(FullTestName + "-composite-model.zip"); + using (var fs = File.Create(compositeLoaderModelPath)) + ml.Model.Save(compositeLoaderModel, fs); + string loaderAndTransformerModelPath = GetOutputPath(FullTestName + "-loader-transformer.zip"); + using (var fs = File.Create(loaderAndTransformerModelPath)) + ml.Model.Save(loader, transformerModel, fs); + + ITransformer loadedTransformerModel; + IDataLoader loadedCompositeLoader; + ITransformer loadedTransformerModel1; + using (var fs = File.OpenRead(modelAndSchemaPath)) + loadedTransformerModel = ml.Model.Load(fs, out DataViewSchema loadedSchema); + using (var fs = File.OpenRead(compositeLoaderModelPath)) + { + // This model can be loaded either as a composite data loader, + // a transformer model + an input schema, or a transformer model + a data loader. + var t = ml.Model.Load(fs, out IDataLoader l); + var t1 = ml.Model.Load(fs, out DataViewSchema s); + loadedCompositeLoader = ml.Model.Load(fs); + } + using (var fs = File.OpenRead(loaderAndTransformerModelPath)) + { + // This model can be loaded either as a composite data loader, + // a transformer model + an input schema, or a transformer model + a data loader. + var t = ml.Model.Load(fs, out DataViewSchema s); + var c = ml.Model.Load(fs); + loadedTransformerModel1 = ml.Model.Load(fs, out IDataLoader l); + } - ITransformer loadedModel; - IDataLoader loadedModel1; - using (var fs = File.OpenRead(modelPath)) - loadedModel = ml.Model.Load(fs, out var loadedSchema); - using (var fs = File.OpenRead(modelPath1)) - loadedModel1 = ml.Model.Load(fs); + var gam = ((loadedTransformerModel as ISingleFeaturePredictionTransformer).Model + as CalibratedModelParametersBase).SubModel + as BinaryClassificationGamModelParameters; + Assert.NotNull(gam); - var gam = ((loadedModel as ISingleFeaturePredictionTransformer).Model + gam = (((loadedCompositeLoader as CompositeDataLoader).Transformer.LastTransformer + as ISingleFeaturePredictionTransformer).Model as CalibratedModelParametersBase).SubModel as BinaryClassificationGamModelParameters; Assert.NotNull(gam); - gam = (((loadedModel1 as CompositeDataLoader).Transformer.LastTransformer + gam = (((loadedTransformerModel1 as TransformerChain).LastTransformer as ISingleFeaturePredictionTransformer).Model as CalibratedModelParametersBase).SubModel as BinaryClassificationGamModelParameters; diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs index 71ab03fe9b..8739dccf62 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs @@ -7,6 +7,7 @@ using System.Collections.Immutable; using System.IO; using System.Linq; +using Microsoft.Data.DataView; using Microsoft.ML; using Microsoft.ML.Data; using Microsoft.ML.RunTests; @@ -157,7 +158,7 @@ private void TrainRegression(string trainDataPath, string testDataPath, string m // When you load the model, it's a 'dynamic' transformer. ITransformer loadedModel; using (var stream = File.OpenRead(modelPath)) - loadedModel = mlContext.Model.Load(stream, out var schema); + loadedModel = mlContext.Model.Load(stream, out DataViewSchema schema); } [Fact] diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs index bf077d8cfc..dab52554fb 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs @@ -131,7 +131,7 @@ private void TrainRegression(string trainDataPath, string testDataPath, string m // When you load the model, it's a 'dynamic' transformer. ITransformer loadedModel; using (var stream = File.OpenRead(modelPath)) - loadedModel = mlContext.Model.Load(stream, out var schema); + loadedModel = mlContext.Model.Load(stream, out DataViewSchema schema); } [Fact] @@ -535,7 +535,7 @@ private static void RunEndToEnd(MLContext mlContext, IDataView trainData, string // Now we can load the model. ITransformer loadedModel; using (var fs = File.OpenRead(modelPath)) - loadedModel = newContext.Model.Load(fs, out var schema); + loadedModel = newContext.Model.Load(fs, out DataViewSchema schema); } public static IDataView PrepareData(MLContext mlContext, IDataView data) diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs index bba3945931..4d2fe7c860 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs @@ -7,6 +7,7 @@ using System.IO; using System.Linq; using System.Runtime.InteropServices; +using Microsoft.Data.DataView; using Microsoft.ML.Data; using Microsoft.ML.RunTests; using Microsoft.ML.TestFramework.Attributes; @@ -505,7 +506,7 @@ public void MatrixFactorizationBackCompat() { using (var fs = File.OpenRead(modelPath)) { - model = ML.Model.Load(fs, out var schema); + model = ML.Model.Load(fs, out DataViewSchema schema); // This model was saved without the input schema. Assert.Null(schema); } diff --git a/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs b/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs index e0167a241e..c8e17036e3 100644 --- a/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs @@ -261,7 +261,7 @@ public void TypeConvertKeyBackCompatTest() using (var ch = Env.Start("load")) { using (var fs = File.OpenRead(modelPath)) - modelOld = ML.Model.Load(fs, out var schema); + modelOld = ML.Model.Load(fs, out DataViewSchema schema); } var outDataOld = modelOld.Transform(dataView); From cf7338d63a01c70760a4e03e385ef8d95b645e28 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Wed, 13 Mar 2019 09:12:19 -0700 Subject: [PATCH 11/18] Fix build --- src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs | 2 +- src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs | 2 +- test/Microsoft.ML.Functional.Tests/ModelLoading.cs | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs index fcc808406c..8e849802ca 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs @@ -18,9 +18,9 @@ namespace Microsoft.ML.Data public sealed class CompositeDataLoader : IDataLoader where TLastTransformer : class, ITransformer { + internal const string TransformerDirectory = TransformerChain.LoaderSignature; private const string LoaderDirectory = "Loader"; private const string LegacyLoaderDirectory = "Reader"; - private const string TransformerDirectory = TransformerChain.LoaderSignature; /// /// The underlying data loader. diff --git a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs index 374019ac8a..863d549d0e 100644 --- a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs @@ -66,7 +66,7 @@ public void Save(DataViewSchema inputSchema, ITransformer model, Stream stream) { using (var rep = RepositoryWriter.CreateNew(stream)) { - ModelSaveContext.SaveModel(rep, model, TransformerDirectory); + ModelSaveContext.SaveModel(rep, model, CompositeDataLoader.TransformerDirectory); SaveInputSchema(inputSchema, rep); rep.Commit(); } diff --git a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs index dc598bc44d..defedf463b 100644 --- a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs +++ b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs @@ -40,9 +40,9 @@ public void LoadModelAndExtractPredictor() var data = loader.Load(file); // Pipeline. - var pipeline = ml.BinaryClassification.Trainers.GeneralizedAdditiveModels(); + var pipeline = ml.BinaryClassification.Trainers.Gam(); // Define the same pipeline starting with the loader. - var pipeline1 = loader.Append(ml.BinaryClassification.Trainers.GeneralizedAdditiveModels()); + var pipeline1 = loader.Append(ml.BinaryClassification.Trainers.Gam()); // Train. var transformerModel = pipeline.Fit(data); @@ -108,7 +108,7 @@ public void SaveAndLoadModelWithLoader() var data = loader.Load(file); // Pipeline. - var pipeline = ml.BinaryClassification.Trainers.GeneralizedAdditiveModels(); + var pipeline = ml.BinaryClassification.Trainers.Gam(); // Train. var model = pipeline.Fit(data); From 24e93bb431a869a87a30c12e204c5fe8ea6a8439 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Wed, 13 Mar 2019 09:24:27 -0700 Subject: [PATCH 12/18] Fix F# tests --- test/Microsoft.ML.FSharp.Tests/SmokeTests.fs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs index 103971663b..292e16f30f 100644 --- a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs +++ b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs @@ -83,7 +83,7 @@ module SmokeTest1 = let model = pipeline.Fit(data) - let engine = model.CreatePredictionEngine(ml) + let engine = ml.Model.CreatePredictionEngine(model) let predictions = [ SentimentData(SentimentText = "This is a gross exaggeration. Nobody is setting a kangaroo court. There was a simple addition.") @@ -123,7 +123,7 @@ module SmokeTest2 = let model = pipeline.Fit(data) - let engine = model.CreatePredictionEngine(ml) + let engine = ml.Model.CreatePredictionEngine(model) let predictions = [ { SentimentText = "This is a gross exaggeration. Nobody is setting a kangaroo court. There was a simple addition."; Sentiment = false } @@ -160,7 +160,7 @@ module SmokeTest3 = let model = pipeline.Fit(data) - let engine = model.CreatePredictionEngine(ml) + let engine = ml.Model.CreatePredictionEngine(model) let predictions = [ SentimentData(SentimentText = "This is a gross exaggeration. Nobody is setting a kangaroo court. There was a simple addition.".AsMemory()) From 487acc322720e2b0575022b660a8150f09ed9b0f Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Wed, 13 Mar 2019 10:14:32 -0700 Subject: [PATCH 13/18] Remove duplicate CreatePredictionEngine API --- .../Utilities/ComponentCreation.cs | 20 ------------------- .../CookbookSamplesDynamicApi.cs | 2 +- ...PlantClassificationWithStringLabelTests.cs | 2 +- .../TensorflowTests.cs | 4 ++-- 4 files changed, 4 insertions(+), 24 deletions(-) diff --git a/src/Microsoft.ML.Data/Utilities/ComponentCreation.cs b/src/Microsoft.ML.Data/Utilities/ComponentCreation.cs index 7bd8a70bdb..33806d5ea4 100644 --- a/src/Microsoft.ML.Data/Utilities/ComponentCreation.cs +++ b/src/Microsoft.ML.Data/Utilities/ComponentCreation.cs @@ -57,26 +57,6 @@ public static RoleMappedData CreateExamples(this IHostEnvironment env, IDataView return new RoleMappedData(data, label, features, group, weight, name: null, custom: custom); } - /// - /// Create an on-demand prediction engine. - /// - /// The host environment to use. - /// The transformer. - /// Whether to ignore missing columns in the data view. - /// The optional input schema. If null, the schema is inferred from the type. - /// The optional output schema. If null, the schema is inferred from the type. - public static PredictionEngine CreatePredictionEngine(this IHostEnvironment env, ITransformer transformer, - bool ignoreMissingColumns = false, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) - where TSrc : class - where TDst : class, new() - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(transformer, nameof(transformer)); - env.CheckValueOrNull(inputSchemaDefinition); - env.CheckValueOrNull(outputSchemaDefinition); - return new PredictionEngine(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition); - } - /// /// Load the transforms (but not loader) from the model steram and apply them to the specified data. /// It is acceptable to have no transforms in the model stream: in this case the original diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs index dab52554fb..8a01d69996 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs @@ -208,7 +208,7 @@ private void PredictOnIris(ITransformer model) // Make the prediction function object. Note that, on average, this call takes around 200x longer // than one prediction, so you might want to cache and reuse the prediction function, instead of // creating one per prediction. - var predictionFunc = mlContext.CreatePredictionEngine(model); + var predictionFunc = mlContext.Model.CreatePredictionEngine(model); // Obtain the prediction. Remember that 'Predict' is not reentrant. If you want to use multiple threads // for simultaneous prediction, make sure each thread is using its own PredictionFunction. diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs index b9278c02c6..e6c54d9f80 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs @@ -45,7 +45,7 @@ public void TrainAndPredictIrisModelWithStringLabelTest() var trainedModel = pipe.Fit(trainData); // Make predictions - var predictFunction = mlContext.CreatePredictionEngine(trainedModel); + var predictFunction = mlContext.Model.CreatePredictionEngine(trainedModel); IrisPredictionWithStringLabel prediction = predictFunction.Predict(new IrisDataWithStringLabel() { SepalLength = 5.1f, diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index ac68d6a1bd..ce62bb4cd2 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -1001,7 +1001,7 @@ public void TensorFlowSentimentClassificationTest() var estimator = mlContext.Transforms.Text.TokenizeWords("TokenizedWords", "Sentiment_Text") .Append(mlContext.Transforms.Conversion.MapValue(lookupMap, "Words", "Ids", new ColumnOptions[] { ("Features", "TokenizedWords") })); var model = estimator.Fit(dataView); - var dataPipe = mlContext.CreatePredictionEngine(model); + var dataPipe = mlContext.Model.CreatePredictionEngine(model); // For explanation on how was the `sentiment_model` created // c.f. https://github.com/dotnet/machinelearning-testdata/blob/master/Microsoft.ML.TensorFlow.TestModels/sentiment_model/README.md @@ -1009,7 +1009,7 @@ public void TensorFlowSentimentClassificationTest() var pipelineModel = mlContext.Model.LoadTensorFlowModel(modelLocation).ScoreTensorFlowModel(new[] { "Prediction/Softmax" }, new[] { "Features" }) .Append(mlContext.Transforms.CopyColumns(("Prediction", "Prediction/Softmax"))) .Fit(dataView); - var tfEnginePipe = mlContext.CreatePredictionEngine(pipelineModel); + var tfEnginePipe = mlContext.Model.CreatePredictionEngine(pipelineModel); var processedData = dataPipe.Predict(data[0]); Array.Resize(ref processedData.Features, 600); From afa4ba4648b87ffeb0a6b112d4d3691c5a4729ee Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Wed, 13 Mar 2019 10:56:36 -0700 Subject: [PATCH 14/18] Add test for creating an IDataView from a loaded schema --- .../DataView/DataViewConstructionUtils.cs | 2 +- .../Model/ModelOperationsCatalog.cs | 19 ++++++----- .../ModelLoading.cs | 33 +++++++++++++++++++ 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs index 6fc6490aa1..1ca8392b36 100644 --- a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs @@ -93,7 +93,7 @@ private static AnnotationInfo GetAnnotationInfo(string kind, DataViewSchema.A { T value = default; annotations.GetValue(kind, ref value); - return new AnnotationInfo(kind, value); + return new AnnotationInfo(kind, value, annotations.Schema[kind].Type); } public static InputRow CreateInputRow(IHostEnvironment env, SchemaDefinition schemaDefinition = null) diff --git a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs index 863d549d0e..0b2e9ea068 100644 --- a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs @@ -32,21 +32,24 @@ internal ModelOperationsCatalog(IHostEnvironment env) Explainability = new ExplainabilityTransforms(this); } - /// - /// Save the model to the stream. - /// - /// The trained model to be saved. - /// A writeable, seekable stream to save to. - public void Save(IDataLoader model, Stream stream) + private void Save(DataViewSchema schema, IDataLoader model, Stream stream) { using (var rep = RepositoryWriter.CreateNew(stream)) { ModelSaveContext.SaveModel(rep, model, null); - SaveInputSchema(model.GetOutputSchema(), rep); + SaveInputSchema(schema, rep); rep.Commit(); } } + /// + /// Save the model to the stream. + /// + /// The trained model to be saved. + /// A writeable, seekable stream to save to. + public void Save(IDataLoader model, Stream stream) + => Save(model.GetOutputSchema(), model, stream); + /// /// Save a transformer model and the loader used to create its input data to the stream. /// @@ -54,7 +57,7 @@ public void Save(IDataLoader model, Stream stream) /// The trained model to be saved /// A writeable, seekable stream to save to. public void Save(IDataLoader loader, ITransformer model, Stream stream) => - Save(new CompositeDataLoader(loader, new TransformerChain(model)), stream); + Save(loader.GetOutputSchema(), new CompositeDataLoader(loader, new TransformerChain(model)), stream); /// /// Save a transformer model and the schema of the data that was used to train it to the stream. diff --git a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs index defedf463b..b56b552472 100644 --- a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs +++ b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs @@ -148,6 +148,39 @@ public void SaveAndLoadModelWithLoader() var ageBinEffects = gamModel.GetBinEffects(ageIndex); } + [Fact] + public void LoadSchemaAndCreateNewData() + { + var ml = new MLContext(seed: 1); + var file = new MultiFileSource(GetDataPath(TestDatasets.adult.trainFilename)); + var loader = ml.Data.CreateTextLoader(hasHeader: true, dataSample: file); + var data = loader.Load(file); + + // Pipeline. + var pipeline = ml.Transforms.Normalize("Features"); + + // Train. + var model = pipeline.Fit(data); + + // Save and reload. + string modelPath = GetOutputPath(FullTestName + "-model.zip"); + using (var fs = File.Create(modelPath)) + ml.Model.Save(loader, model, fs); + + ITransformer loadedModel; + DataViewSchema loadedSchema; + using (var fs = File.OpenRead(modelPath)) + loadedModel = ml.Model.Load(fs, out loadedSchema); + + // Without using the schema from the model we lose the slot names. + data = ml.Data.LoadFromEnumerable(new[] { new InputData() }); + data = loadedModel.Transform(data); + Assert.True(!data.Schema["Features"].HasSlotNames()); + + data = ml.Data.LoadFromEnumerable(new[] { new InputData() }, loadedSchema); + Assert.True(data.Schema["Features"].HasSlotNames()); + } + private int FindIndex(ReadOnlySpan> values, string slotName) { int index = 0; From 9d21951e31078737f4dc9bc594d4d4269ac878b8 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Thu, 14 Mar 2019 09:15:32 -0700 Subject: [PATCH 15/18] Fix build error after rebase --- .../Dynamic/TensorFlow/TextClassification.cs | 2 +- .../ScenariosWithDirectInstantiation/TensorflowTests.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/TextClassification.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/TextClassification.cs index 7a28c2a463..30310b755b 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/TextClassification.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/TextClassification.cs @@ -68,7 +68,7 @@ public static void Example() j.Features = features; }; - var model = mlContext.Transforms.Text.TokenizeWords("TokenizedWords", "Sentiment_Text") + var model = mlContext.Transforms.Text.TokenizeIntoWords("TokenizedWords", "Sentiment_Text") .Append(mlContext.Transforms.Conversion.MapValue(lookupMap, "Words", "Ids", new ColumnOptions[] { ("VariableLenghtFeatures", "TokenizedWords") })) .Append(mlContext.Transforms.CustomMapping(ResizeFeaturesAction, "Resize")) .Append(tensorFlowModel.ScoreTensorFlowModel(new[] { "Prediction/Softmax" }, new[] { "Features" })) diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index ce62bb4cd2..b8c417b87b 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -998,7 +998,7 @@ public void TensorFlowSentimentClassificationTest() // The first pipeline 'dataPipe' tokenzies the string into words and maps each word to an integer which is an index in the dictionary. // Then this integer vector is retrieved from the pipeline and resized to fixed length. // The second pipeline 'tfEnginePipe' takes the resized integer vector and passes it to TensoFlow and gets the classification scores. - var estimator = mlContext.Transforms.Text.TokenizeWords("TokenizedWords", "Sentiment_Text") + var estimator = mlContext.Transforms.Text.TokenizeIntoWords("TokenizedWords", "Sentiment_Text") .Append(mlContext.Transforms.Conversion.MapValue(lookupMap, "Words", "Ids", new ColumnOptions[] { ("Features", "TokenizedWords") })); var model = estimator.Fit(dataView); var dataPipe = mlContext.Model.CreatePredictionEngine(model); From d625b4ce7614ba87b7a93c7d000fe73eb5f4e0b3 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Thu, 14 Mar 2019 14:04:48 -0700 Subject: [PATCH 16/18] Add unit tests, and address some code review comments --- .../DataLoadSave/TransformerChain.cs | 81 +++---- .../Model/ModelOperationsCatalog.cs | 83 +++++-- .../UnitTests/TestEntryPoints.cs | 1 - .../ModelLoading.cs | 206 +++++++++++++++--- .../DataPipe/TestDataPipeBase.cs | 2 +- test/Microsoft.ML.Tests/ImagesTests.cs | 6 +- .../Api/CookbookSamples/CookbookSamples.cs | 2 +- .../CookbookSamplesDynamicApi.cs | 4 +- .../Estimators/TrainSaveModelAndPredict.cs | 2 +- .../Transformers/CopyColumnEstimatorTests.cs | 4 +- .../Transformers/SelectColumnsTests.cs | 9 +- .../Transformers/ValueMappingTests.cs | 5 +- .../TimeSeriesDirectApi.cs | 7 +- 13 files changed, 297 insertions(+), 115 deletions(-) diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs index 6eabf64449..f9f7a1f413 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs @@ -250,54 +250,41 @@ private static TransformerChain Create(IHostEnvironment env, Model public static void SaveTo(this ITransformer transformer, IHostEnvironment env, Stream outputStream) => new TransformerChain(transformer).SaveTo(env, outputStream); - public static ITransformer LoadFrom(IHostEnvironment env, Stream stream) + public static ITransformer LoadFromLegacy(IHostEnvironment env, Stream stream) { - using (var rep = RepositoryReader.Open(stream, env)) - { - try - { - ModelLoadContext.LoadModel(env, out var transformerChain, rep, LoaderSignature); - return transformerChain; - } - catch (FormatException ex) - { - if (!ex.IsMarked()) - throw; - var chain = ModelFileUtils.LoadPipeline(env, stream, new MultiFileSource(null), extractInnerPipe: false); - TransformerChain transformChain = (chain as LegacyCompositeDataLoader).GetTransformer(); - var predictor = ModelFileUtils.LoadPredictorOrNull(env, stream); - if (predictor == null) - return transformChain; - var roles = ModelFileUtils.LoadRoleMappingsOrNull(env, stream); - env.CheckDecode(roles != null, "Predictor model must contain role mappings"); - var roleMappings = roles.ToArray(); - - ITransformer pred = null; - if (predictor.PredictionKind == PredictionKind.BinaryClassification) - pred = new BinaryPredictionTransformer>(env, predictor as IPredictorProducing, chain.Schema, - roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value); - else if (predictor.PredictionKind == PredictionKind.MulticlassClassification) - pred = new MulticlassPredictionTransformer>>(env, - predictor as IPredictorProducing>, chain.Schema, - roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value, - roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value); - else if (predictor.PredictionKind == PredictionKind.Clustering) - pred = new ClusteringPredictionTransformer>>(env, predictor as IPredictorProducing>, chain.Schema, - roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value); - else if (predictor.PredictionKind == PredictionKind.Regression) - pred = new RegressionPredictionTransformer>(env, predictor as IPredictorProducing, chain.Schema, - roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value); - else if (predictor.PredictionKind == PredictionKind.AnomalyDetection) - pred = new AnomalyPredictionTransformer>(env, predictor as IPredictorProducing, chain.Schema, - roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value); - else if (predictor.PredictionKind == PredictionKind.Ranking) - pred = new RankingPredictionTransformer>(env, predictor as IPredictorProducing, chain.Schema, - roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value); - else - throw env.Except("Don't know how to map prediction kind {0}", predictor.PredictionKind); - return transformChain.Append(pred); - } - } + var chain = ModelFileUtils.LoadPipeline(env, stream, new MultiFileSource(null), extractInnerPipe: false); + TransformerChain transformChain = (chain as LegacyCompositeDataLoader).GetTransformer(); + var predictor = ModelFileUtils.LoadPredictorOrNull(env, stream); + if (predictor == null) + return transformChain; + var roles = ModelFileUtils.LoadRoleMappingsOrNull(env, stream); + env.CheckDecode(roles != null, "Predictor model must contain role mappings"); + var roleMappings = roles.ToArray(); + + ITransformer pred = null; + if (predictor.PredictionKind == PredictionKind.BinaryClassification) + pred = new BinaryPredictionTransformer>(env, predictor as IPredictorProducing, chain.Schema, + roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value); + else if (predictor.PredictionKind == PredictionKind.MulticlassClassification) + pred = new MulticlassPredictionTransformer>>(env, + predictor as IPredictorProducing>, chain.Schema, + roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value, + roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value); + else if (predictor.PredictionKind == PredictionKind.Clustering) + pred = new ClusteringPredictionTransformer>>(env, predictor as IPredictorProducing>, chain.Schema, + roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value); + else if (predictor.PredictionKind == PredictionKind.Regression) + pred = new RegressionPredictionTransformer>(env, predictor as IPredictorProducing, chain.Schema, + roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value); + else if (predictor.PredictionKind == PredictionKind.AnomalyDetection) + pred = new AnomalyPredictionTransformer>(env, predictor as IPredictorProducing, chain.Schema, + roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value); + else if (predictor.PredictionKind == PredictionKind.Ranking) + pred = new RankingPredictionTransformer>(env, predictor as IPredictorProducing, chain.Schema, + roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value); + else + throw env.Except("Don't know how to map prediction kind {0}", predictor.PredictionKind); + return transformChain.Append(pred); } } } diff --git a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs index 0b2e9ea068..251577b945 100644 --- a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs @@ -4,6 +4,7 @@ using System; using System.IO; +using System.Linq; using Microsoft.Data.DataView; using Microsoft.ML.Data; using Microsoft.ML.Data.IO; @@ -32,24 +33,23 @@ internal ModelOperationsCatalog(IHostEnvironment env) Explainability = new ExplainabilityTransforms(this); } - private void Save(DataViewSchema schema, IDataLoader model, Stream stream) + /// + /// Save the model to the stream. + /// + /// The trained model to be saved. + /// A writeable, seekable stream to save to. + public void Save(IDataLoader model, Stream stream) { + _env.CheckValue(model, nameof(model)); + _env.CheckValue(stream, nameof(stream)); + using (var rep = RepositoryWriter.CreateNew(stream)) { ModelSaveContext.SaveModel(rep, model, null); - SaveInputSchema(schema, rep); rep.Commit(); } } - /// - /// Save the model to the stream. - /// - /// The trained model to be saved. - /// A writeable, seekable stream to save to. - public void Save(IDataLoader model, Stream stream) - => Save(model.GetOutputSchema(), model, stream); - /// /// Save a transformer model and the loader used to create its input data to the stream. /// @@ -57,16 +57,20 @@ public void Save(IDataLoader model, Stream stream) /// The trained model to be saved /// A writeable, seekable stream to save to. public void Save(IDataLoader loader, ITransformer model, Stream stream) => - Save(loader.GetOutputSchema(), new CompositeDataLoader(loader, new TransformerChain(model)), stream); + Save(new CompositeDataLoader(loader, new TransformerChain(model)), stream); /// /// Save a transformer model and the schema of the data that was used to train it to the stream. /// - /// The schema of the input to the transformer. /// The trained model to be saved. + /// The schema of the input to the transformer. This can be null. /// A writeable, seekable stream to save to. - public void Save(DataViewSchema inputSchema, ITransformer model, Stream stream) + public void Save(ITransformer model, DataViewSchema inputSchema, Stream stream) { + _env.CheckValue(model, nameof(model)); + _env.CheckValueOrNull(inputSchema); + _env.CheckValue(stream, nameof(stream)); + using (var rep = RepositoryWriter.CreateNew(stream)) { ModelSaveContext.SaveModel(rep, model, CompositeDataLoader.TransformerDirectory); @@ -77,6 +81,12 @@ public void Save(DataViewSchema inputSchema, ITransformer model, Stream stream) private void SaveInputSchema(DataViewSchema inputSchema, RepositoryWriter rep) { + _env.AssertValueOrNull(inputSchema); + _env.AssertValue(rep); + + if (inputSchema == null) + return; + using (var ch = _env.Start("Saving Schema")) { var entry = rep.CreateEntry(SchemaEntryName); @@ -94,6 +104,8 @@ private void SaveInputSchema(DataViewSchema inputSchema, RepositoryWriter rep) /// The loaded model. public ITransformer Load(Stream stream, out DataViewSchema inputSchema) { + _env.CheckValue(stream, nameof(stream)); + using (var rep = RepositoryReader.Open(stream, _env)) { var entry = rep.OpenEntryOrNull(SchemaEntryName); @@ -101,23 +113,41 @@ public ITransformer Load(Stream stream, out DataViewSchema inputSchema) { var loader = new BinaryLoader(_env, new BinaryLoader.Arguments(), entry.Stream); inputSchema = loader.Schema; + ModelLoadContext.LoadModel(_env, out var transformerChain, rep, + CompositeDataLoader.TransformerDirectory); + return transformerChain; } - else + + ModelLoadContext.LoadModelOrNull, SignatureLoadModel>(_env, out var dataLoader, rep, null); + if (dataLoader == null) { + // Try to see if the model was saved without a loader or a schema. + if (ModelLoadContext.LoadModelOrNull(_env, out var transformerChain, rep, + CompositeDataLoader.TransformerDirectory)) + { + inputSchema = null; + return transformerChain; + } + // Try to load from legacy model format. try { var loader = ModelFileUtils.LoadLoader(_env, rep, new MultiFileSource(null), false); inputSchema = loader.Schema; + return TransformerChain.LoadFromLegacy(_env, stream); } catch (Exception ex) { - if (!ex.IsMarked()) - throw; - inputSchema = null; + throw _env.Except(ex, "Could not load legacy format model"); } } - return TransformerChain.LoadFrom(_env, stream); + if (dataLoader is CompositeDataLoader composite) + { + inputSchema = composite.Loader.GetOutputSchema(); + return composite.Transformer; + } + inputSchema = dataLoader.GetOutputSchema(); + return new TransformerChain(); } } @@ -127,12 +157,21 @@ public ITransformer Load(Stream stream, out DataViewSchema inputSchema) /// A readable, seekable stream to load from. /// A model of type containing the loader /// and the transformer chain. - public CompositeDataLoader Load(Stream stream) + public IDataLoader Load(Stream stream) { + _env.CheckValue(stream, nameof(stream)); + using (var rep = RepositoryReader.Open(stream)) { - ModelLoadContext.LoadModel, SignatureLoadModel>(_env, out var model, rep, null); - return model; + try + { + ModelLoadContext.LoadModel, SignatureLoadModel>(_env, out var model, rep, null); + return model; + } + catch (Exception ex) + { + throw _env.Except(ex, "Model does not contain an IDataLoader"); + } } } @@ -144,6 +183,8 @@ public CompositeDataLoader Load(Stream stream) /// The transformer model from the model stream. public ITransformer Load(Stream stream, out IDataLoader loader) { + _env.CheckValue(stream, nameof(stream)); + loader = Load(stream); if (loader is CompositeDataLoader composite) { diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index f9753997d5..c1183825ab 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -5647,7 +5647,6 @@ public void LoadEntryPointModel() { loadedModel = ml.Model.Load(stream, out DataViewSchema inputSchema); } - } } } diff --git a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs index b56b552472..cd719cfa03 100644 --- a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs +++ b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs @@ -11,6 +11,7 @@ using Microsoft.ML.RunTests; using Microsoft.ML.TestFramework; using Microsoft.ML.Trainers.FastTree; +using Microsoft.ML.Transforms; using Xunit; using Xunit.Abstractions; @@ -18,10 +19,20 @@ namespace Microsoft.ML.Functional.Tests { public partial class ModelLoadingTests : BaseTestClass { + private MLContext _ml; + public ModelLoadingTests(ITestOutputHelper output) : base(output) { } + protected override void Initialize() + { + base.Initialize(); + + _ml = new MLContext(42); + _ml.AddStandardComponents(); + } + private class InputData { [LoadColumn(0)] @@ -34,15 +45,14 @@ private class InputData [Fact] public void LoadModelAndExtractPredictor() { - var ml = new MLContext(seed: 1); var file = new MultiFileSource(GetDataPath(TestDatasets.adult.trainFilename)); - var loader = ml.Data.CreateTextLoader(hasHeader: true, dataSample: file); + var loader = _ml.Data.CreateTextLoader(hasHeader: true, dataSample: file); var data = loader.Load(file); // Pipeline. - var pipeline = ml.BinaryClassification.Trainers.Gam(); + var pipeline = _ml.BinaryClassification.Trainers.Gam(); // Define the same pipeline starting with the loader. - var pipeline1 = loader.Append(ml.BinaryClassification.Trainers.Gam()); + var pipeline1 = loader.Append(_ml.BinaryClassification.Trainers.Gam()); // Train. var transformerModel = pipeline.Fit(data); @@ -51,34 +61,34 @@ public void LoadModelAndExtractPredictor() // Save and reload. string modelAndSchemaPath = GetOutputPath(FullTestName + "-model-schema.zip"); using (var fs = File.Create(modelAndSchemaPath)) - ml.Model.Save(data.Schema, transformerModel, fs); + _ml.Model.Save(transformerModel, data.Schema, fs); string compositeLoaderModelPath = GetOutputPath(FullTestName + "-composite-model.zip"); using (var fs = File.Create(compositeLoaderModelPath)) - ml.Model.Save(compositeLoaderModel, fs); + _ml.Model.Save(compositeLoaderModel, fs); string loaderAndTransformerModelPath = GetOutputPath(FullTestName + "-loader-transformer.zip"); using (var fs = File.Create(loaderAndTransformerModelPath)) - ml.Model.Save(loader, transformerModel, fs); + _ml.Model.Save(loader, transformerModel, fs); ITransformer loadedTransformerModel; IDataLoader loadedCompositeLoader; ITransformer loadedTransformerModel1; using (var fs = File.OpenRead(modelAndSchemaPath)) - loadedTransformerModel = ml.Model.Load(fs, out DataViewSchema loadedSchema); + loadedTransformerModel = _ml.Model.Load(fs, out DataViewSchema loadedSchema); using (var fs = File.OpenRead(compositeLoaderModelPath)) { // This model can be loaded either as a composite data loader, // a transformer model + an input schema, or a transformer model + a data loader. - var t = ml.Model.Load(fs, out IDataLoader l); - var t1 = ml.Model.Load(fs, out DataViewSchema s); - loadedCompositeLoader = ml.Model.Load(fs); + var t = _ml.Model.Load(fs, out IDataLoader l); + var t1 = _ml.Model.Load(fs, out DataViewSchema s); + loadedCompositeLoader = _ml.Model.Load(fs); } using (var fs = File.OpenRead(loaderAndTransformerModelPath)) { // This model can be loaded either as a composite data loader, // a transformer model + an input schema, or a transformer model + a data loader. - var t = ml.Model.Load(fs, out DataViewSchema s); - var c = ml.Model.Load(fs); - loadedTransformerModel1 = ml.Model.Load(fs, out IDataLoader l); + var t = _ml.Model.Load(fs, out DataViewSchema s); + var c = _ml.Model.Load(fs); + loadedTransformerModel1 = _ml.Model.Load(fs, out IDataLoader l); } var gam = ((loadedTransformerModel as ISingleFeaturePredictionTransformer).Model @@ -102,13 +112,12 @@ public void LoadModelAndExtractPredictor() [Fact] public void SaveAndLoadModelWithLoader() { - var ml = new MLContext(seed: 1); var file = new MultiFileSource(GetDataPath(TestDatasets.adult.trainFilename)); - var loader = ml.Data.CreateTextLoader(hasHeader: true, dataSample: file); + var loader = _ml.Data.CreateTextLoader(hasHeader: true, dataSample: file); var data = loader.Load(file); // Pipeline. - var pipeline = ml.BinaryClassification.Trainers.Gam(); + var pipeline = _ml.BinaryClassification.Trainers.Gam(); // Train. var model = pipeline.Fit(data); @@ -116,19 +125,19 @@ public void SaveAndLoadModelWithLoader() // Save and reload. string modelPath = GetOutputPath(FullTestName + "-model.zip"); using (var fs = File.Create(modelPath)) - ml.Model.Save(loader, model, fs); + _ml.Model.Save(loader, model, fs); IDataLoader loadedModel; ITransformer loadedModelWithoutLoader; DataViewSchema loadedSchema; using (var fs = File.OpenRead(modelPath)) { - loadedModel = ml.Model.Load(fs); - loadedModelWithoutLoader = ml.Model.Load(fs, out loadedSchema); + loadedModel = _ml.Model.Load(fs); + loadedModelWithoutLoader = _ml.Model.Load(fs, out loadedSchema); } // Without deserializing the loader from the model we lose the slot names. - data = ml.Data.LoadFromEnumerable(new[] { new InputData() }); + data = _ml.Data.LoadFromEnumerable(new[] { new InputData() }); data = loadedModelWithoutLoader.Transform(data); Assert.True(!data.Schema["Features"].HasSlotNames()); @@ -151,13 +160,12 @@ public void SaveAndLoadModelWithLoader() [Fact] public void LoadSchemaAndCreateNewData() { - var ml = new MLContext(seed: 1); var file = new MultiFileSource(GetDataPath(TestDatasets.adult.trainFilename)); - var loader = ml.Data.CreateTextLoader(hasHeader: true, dataSample: file); + var loader = _ml.Data.CreateTextLoader(hasHeader: true, dataSample: file); var data = loader.Load(file); // Pipeline. - var pipeline = ml.Transforms.Normalize("Features"); + var pipeline = _ml.Transforms.Normalize("Features"); // Train. var model = pipeline.Fit(data); @@ -165,22 +173,164 @@ public void LoadSchemaAndCreateNewData() // Save and reload. string modelPath = GetOutputPath(FullTestName + "-model.zip"); using (var fs = File.Create(modelPath)) - ml.Model.Save(loader, model, fs); + _ml.Model.Save(loader, model, fs); ITransformer loadedModel; DataViewSchema loadedSchema; using (var fs = File.OpenRead(modelPath)) - loadedModel = ml.Model.Load(fs, out loadedSchema); + loadedModel = _ml.Model.Load(fs, out loadedSchema); // Without using the schema from the model we lose the slot names. - data = ml.Data.LoadFromEnumerable(new[] { new InputData() }); + data = _ml.Data.LoadFromEnumerable(new[] { new InputData() }); data = loadedModel.Transform(data); Assert.True(!data.Schema["Features"].HasSlotNames()); - data = ml.Data.LoadFromEnumerable(new[] { new InputData() }, loadedSchema); + data = _ml.Data.LoadFromEnumerable(new[] { new InputData() }, loadedSchema); Assert.True(data.Schema["Features"].HasSlotNames()); } + [Fact] + public void SaveTextLoaderAndLoad() + { + var file = new MultiFileSource(GetDataPath(TestDatasets.adult.trainFilename)); + var loader = _ml.Data.CreateTextLoader(hasHeader: true, dataSample: file); + + string modelPath = GetOutputPath(FullTestName + "-model.zip"); + using (var fs = File.Create(modelPath)) + _ml.Model.Save(loader, fs); + + Load(modelPath, out var loadedWithSchema, out var loadedSchema, out var loadedLoader, + out var loadedWithLoader, out var loadedLoaderWithTransformer); + Assert.True(loadedWithSchema is TransformerChain); + Assert.False((loadedWithSchema as TransformerChain).Any()); + Assert.True(loadedSchema.Count == 2 && + loadedSchema.GetColumnOrNull("Label") != null + && loadedSchema.GetColumnOrNull("Features") != null + && loadedSchema["Features"].HasSlotNames()); + Assert.True(loadedLoader is TextLoader); + Assert.True(loadedWithLoader is TransformerChain); + Assert.False((loadedWithLoader as TransformerChain).Any()); + Assert.True(loadedLoaderWithTransformer is TextLoader); + var schema = loadedLoaderWithTransformer.GetOutputSchema(); + Assert.True(schema.Count == 2 && + schema.GetColumnOrNull("Label") != null + && schema.GetColumnOrNull("Features") != null + && schema["Features"].HasSlotNames()); + } + + [Fact] + public void SaveCompositeLoaderAndLoad() + { + var file = new MultiFileSource(GetDataPath(TestDatasets.adult.trainFilename)); + var loader = _ml.Data.CreateTextLoader(hasHeader: true, dataSample: file); + var composite = loader.Append(_ml.Transforms.Normalize("Features")); + var model = composite.Fit(file); + + string modelPath = GetOutputPath(FullTestName + "-model.zip"); + using (var fs = File.Create(modelPath)) + _ml.Model.Save(model, fs); + + Load(modelPath, out var loadedWithSchema, out var loadedSchema, out var loadedLoader, + out var loadedWithLoader, out var loadedLoaderWithTransformer); + Assert.True(loadedWithSchema is TransformerChain); + Assert.True((loadedWithSchema as TransformerChain).Count() == 1); + Assert.True(loadedSchema.Count == 2 && + loadedSchema.GetColumnOrNull("Label") != null + && loadedSchema.GetColumnOrNull("Features") != null + && loadedSchema["Features"].HasSlotNames()); + Assert.True(loadedLoader is CompositeDataLoader); + Assert.True(loadedWithLoader is TransformerChain); + Assert.True((loadedWithLoader as TransformerChain).Count() == 1); + Assert.True(loadedLoaderWithTransformer is TextLoader); + var schema = loadedLoaderWithTransformer.GetOutputSchema(); + Assert.True(schema.Count == 2 && + schema.GetColumnOrNull("Label") != null + && schema.GetColumnOrNull("Features") != null + && schema["Features"].HasSlotNames()); + } + + [Fact] + public void SaveLoaderAndTransformerAndLoad() + { + var file = new MultiFileSource(GetDataPath(TestDatasets.adult.trainFilename)); + var loader = _ml.Data.CreateTextLoader(hasHeader: true, dataSample: file); + var estimator = _ml.Transforms.Normalize("Features"); + var model = estimator.Fit(loader.Load(file)); + + string modelPath = GetOutputPath(FullTestName + "-model.zip"); + using (var fs = File.Create(modelPath)) + _ml.Model.Save(loader, model, fs); + + Load(modelPath, out var loadedWithSchema, out var loadedSchema, out var loadedLoader, + out var loadedWithLoader, out var loadedLoaderWithTransformer); + Assert.True(loadedWithSchema is TransformerChain); + Assert.True((loadedWithSchema as TransformerChain).Count() == 1); + Assert.True(loadedSchema.Count == 2 && + loadedSchema.GetColumnOrNull("Label") != null + && loadedSchema.GetColumnOrNull("Features") != null + && loadedSchema["Features"].HasSlotNames()); + Assert.True(loadedLoader is CompositeDataLoader); + Assert.True(loadedWithLoader is TransformerChain); + Assert.True((loadedWithLoader as TransformerChain).Count() == 1); + Assert.True(loadedLoaderWithTransformer is TextLoader); + var schema = loadedLoaderWithTransformer.GetOutputSchema(); + Assert.True(schema.Count == 2 && + schema.GetColumnOrNull("Label") != null + && schema.GetColumnOrNull("Features") != null + && schema["Features"].HasSlotNames()); + } + + [Fact] + public void SaveTransformerAndSchemaAndLoad() + { + var file = new MultiFileSource(GetDataPath(TestDatasets.adult.trainFilename)); + var loader = _ml.Data.CreateTextLoader(hasHeader: true, dataSample: file); + var estimator = _ml.Transforms.Normalize("Features"); + var model = estimator.Fit(loader.Load(file)); + + string modelPath = GetOutputPath(FullTestName + "-model.zip"); + using (var fs = File.Create(modelPath)) + _ml.Model.Save(model, loader.GetOutputSchema(), fs); + + Load(modelPath, out var loadedWithSchema, out var loadedSchema, out var loadedLoader, + out var loadedWithLoader, out var loadedLoaderWithTransformer); + Assert.True(loadedWithSchema is NormalizingTransformer); + Assert.True(loadedSchema.Count == 2 && + loadedSchema.GetColumnOrNull("Label") != null + && loadedSchema.GetColumnOrNull("Features") != null + && loadedSchema["Features"].HasSlotNames()); + Assert.Null(loadedLoader); + Assert.Null(loadedWithLoader); + Assert.Null(loadedLoaderWithTransformer); + } + + private void Load(string filename, out ITransformer loadedWithSchema, out DataViewSchema loadedSchema, + out IDataLoader loadedLoader, out ITransformer loadedWithLoader, + out IDataLoader loadedLoaderWithTransformer) + { + using (var fs = File.OpenRead(filename)) + { + try + { + loadedLoader = _ml.Model.Load(fs); + } + catch (Exception) + { + loadedLoader = null; + } + loadedWithSchema = _ml.Model.Load(fs, out loadedSchema); + try + { + loadedWithLoader = _ml.Model.Load(fs, out loadedLoaderWithTransformer); + } + catch (Exception) + { + loadedWithLoader = null; + loadedLoaderWithTransformer = null; + } + } + } + private int FindIndex(ReadOnlySpan> values, string slotName) { int index = 0; diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index 166c5e8b0a..31c7d02935 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -80,7 +80,7 @@ protected void TestEstimatorCore(IEstimator estimator, // Save and reload. string modelPath = GetOutputPath(FullTestName + "-model.zip"); using (var fs = File.Create(modelPath)) - ML.Model.Save(validFitInput.Schema, transformer, fs); + ML.Model.Save(transformer, validFitInput.Schema, fs); ITransformer loadedTransformer; DataViewSchema loadedInputSchema; diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index 559e4a6f05..b38d574326 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -80,8 +80,10 @@ public void TestEstimatorSaveLoad() using (var file = new SimpleFileHandle(env, tempPath, true, true)) { using (var fs = file.CreateWriteStream()) - model.SaveTo(env, fs); - var model2 = TransformerChain.LoadFrom(env, file.OpenReadStream()); + ML.Model.Save(model, null, fs); + ITransformer model2; + using (var fs = file.OpenReadStream()) + model2 = ML.Model.Load(fs, out DataViewSchema schema); var transformerChain = model2 as TransformerChain; Assert.NotNull(transformerChain); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs index 8739dccf62..cf8c900c54 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs @@ -150,7 +150,7 @@ private void TrainRegression(string trainDataPath, string testDataPath, string m using (var stream = File.Create(modelPath)) { // Saving and loading happens to 'dynamic' models, so the static typing is lost in the process. - mlContext.Model.Save(trainData.AsDynamic.Schema, model.AsDynamic, stream); + mlContext.Model.Save(model.AsDynamic, trainData.AsDynamic.Schema, stream); } // Potentially, the lines below can be in a different process altogether. diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs index 8a01d69996..1ebbb16729 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs @@ -123,7 +123,7 @@ private void TrainRegression(string trainDataPath, string testDataPath, string m using (var stream = File.Create(modelPath)) { // Saving and loading happens to 'dynamic' models. - mlContext.Model.Save(trainData.Schema, model, stream); + mlContext.Model.Save(model, trainData.Schema, stream); } // Potentially, the lines below can be in a different process altogether. @@ -523,7 +523,7 @@ private static void RunEndToEnd(MLContext mlContext, IDataView trainData, string // Save the model. using (var fs = File.Create(modelPath)) - mlContext.Model.Save(cachedTrainData.Schema, model, fs); + mlContext.Model.Save(model, cachedTrainData.Schema, fs); // Now pretend we are in a different process. var newContext = new MLContext(); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs index 9e6cfe9fce..9e84f775b8 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs @@ -37,7 +37,7 @@ public void TrainSaveModelAndPredict() var modelPath = GetOutputPath("temp.zip"); // Save model. using (var file = File.Create(modelPath)) - ml.Model.Save(data.Schema, model, file); + ml.Model.Save(model, data.Schema, file); // Load model. ITransformer loadedModel; diff --git a/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs index 048b363f8e..bcf85537bc 100644 --- a/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs @@ -93,9 +93,9 @@ void TestSavingAndLoading() var transformer = est.Fit(dataView); using (var ms = new MemoryStream()) { - transformer.SaveTo(env, ms); + env.Model.Save(transformer, null, ms); ms.Position = 0; - var loadedTransformer = TransformerChain.LoadFrom(env, ms); + var loadedTransformer = env.Model.Load(ms, out DataViewSchema schema); var result = loadedTransformer.Transform(dataView); ValidateCopyColumnTransformer(result); } diff --git a/test/Microsoft.ML.Tests/Transformers/SelectColumnsTests.cs b/test/Microsoft.ML.Tests/Transformers/SelectColumnsTests.cs index 6ea6675a10..a6c46e8dc5 100644 --- a/test/Microsoft.ML.Tests/Transformers/SelectColumnsTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/SelectColumnsTests.cs @@ -4,6 +4,7 @@ using System; using System.IO; +using Microsoft.Data.DataView; using Microsoft.ML.Data; using Microsoft.ML.Model; using Microsoft.ML.RunTests; @@ -180,9 +181,9 @@ void TestSelectSavingAndLoading() var transformer = est.Fit(dataView); using (var ms = new MemoryStream()) { - transformer.SaveTo(Env, ms); + ML.Model.Save(transformer, null, ms); ms.Position = 0; - var loadedTransformer = TransformerChain.LoadFrom(Env, ms); + var loadedTransformer = ML.Model.Load(ms, out DataViewSchema schema); var result = loadedTransformer.Transform(dataView); Assert.Equal(2, result.Schema.Count); Assert.Equal("A", result.Schema[0].Name); @@ -200,9 +201,9 @@ void TestSelectSavingAndLoadingWithNoKeepHidden() var transformer = est.Fit(dataView); using (var ms = new MemoryStream()) { - transformer.SaveTo(Env, ms); + ML.Model.Save(transformer, null, ms); ms.Position = 0; - var loadedTransformer = TransformerChain.LoadFrom(Env, ms); + var loadedTransformer = ML.Model.Load(ms, out DataViewSchema schema); var result = loadedTransformer.Transform(dataView); Assert.Equal(2, result.Schema.Count); Assert.Equal("A", result.Schema[0].Name); diff --git a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs index c98617c8a8..fac58175f3 100644 --- a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using Microsoft.Data.DataView; using Microsoft.ML.Data; using Microsoft.ML.Model; using Microsoft.ML.RunTests; @@ -600,9 +601,9 @@ void TestSavingAndLoading() var transformer = est.Fit(dataView); using (var ms = new MemoryStream()) { - transformer.SaveTo(Env, ms); + ML.Model.Save(transformer, null, ms); ms.Position = 0; - var loadedTransformer = TransformerChain.LoadFrom(Env, ms); + var loadedTransformer = ML.Model.Load(ms, out DataViewSchema schema); var result = loadedTransformer.Transform(dataView); Assert.Equal(5, result.Schema.Count); Assert.True(result.Schema.TryGetColumnIndex("D", out int col)); diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs index 1b7975df5c..1f75171284 100644 --- a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs +++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.IO; +using Microsoft.Data.DataView; using Microsoft.ML.Data; using Microsoft.ML.TestFramework.Attributes; using Microsoft.ML.Transforms.TimeSeries; @@ -179,7 +180,7 @@ public void ChangePointDetectionWithSeasonalityPredictionEngineNoColumn() //with "engine". ITransformer model2 = null; using (var file = File.OpenRead(modelPath)) - model2 = TransformerChain.LoadFrom(ml, file); + model2 = ml.Model.Load(file, out DataViewSchema schema); //Raw score after state gets updated with two inputs. var engine2 = model2.CreateTimeSeriesPredictionFunction(ml); @@ -199,7 +200,7 @@ public void ChangePointDetectionWithSeasonalityPredictionEngineNoColumn() engine.CheckPoint(ml, modelPath + 1); ITransformer model3 = null; using (var file = File.OpenRead(modelPath + 1)) - model3 = TransformerChain.LoadFrom(ml, file); + model3 = ml.Model.Load(file, out DataViewSchema schema); //Load the model with state updated with just one input, then pass in the second input //and raw score should match the raw score obtained by passing the two input in the first model. @@ -266,7 +267,7 @@ public void ChangePointDetectionWithSeasonalityPredictionEngine() // Load Model 1. ITransformer model2 = null; using (var file = File.OpenRead(modelPath)) - model2 = TransformerChain.LoadFrom(ml, file); + model2 = ml.Model.Load(file, out DataViewSchema schema); //Predict and expect the same result after checkpointing(Prediction #2). engine = model2.CreateTimeSeriesPredictionFunction(ml); From 94eefa13fe1b07821367d2d2bae3fe2beceae1ce Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Mon, 18 Mar 2019 10:55:58 -0700 Subject: [PATCH 17/18] Fix build after rebase --- .../Model/ModelOperationsCatalog.cs | 12 ------------ test/Microsoft.ML.Functional.Tests/ModelLoading.cs | 1 - .../Scenarios/Api/CookbookSamples/CookbookSamples.cs | 1 - .../Api/Estimators/TrainSaveModelAndPredict.cs | 1 - .../TrainerEstimators/MatrixFactorizationTests.cs | 1 - .../Transformers/SelectColumnsTests.cs | 1 - .../Transformers/ValueMappingTests.cs | 1 - .../TimeSeriesDirectApi.cs | 1 - 8 files changed, 19 deletions(-) diff --git a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs index 251577b945..a7720185b0 100644 --- a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs @@ -5,7 +5,6 @@ using System; using System.IO; using System.Linq; -using Microsoft.Data.DataView; using Microsoft.ML.Data; using Microsoft.ML.Data.IO; using Microsoft.ML.Model; @@ -194,17 +193,6 @@ public ITransformer Load(Stream stream, out IDataLoader load return new TransformerChain(); } - /// - /// Load the model from a file path. - /// - /// Path to model. - /// The loaded model. - public ITransformer Load(string modelPath) - { - using (var stream = File.OpenRead(modelPath)) - return Load(stream); - } - /// /// The catalog of model explainability operations. /// diff --git a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs index cd719cfa03..2adce45953 100644 --- a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs +++ b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs @@ -5,7 +5,6 @@ using System; using System.IO; using System.Linq; -using Microsoft.Data.DataView; using Microsoft.ML.Calibrators; using Microsoft.ML.Data; using Microsoft.ML.RunTests; diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs index cf8c900c54..c0060b9e81 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs @@ -7,7 +7,6 @@ using System.Collections.Immutable; using System.IO; using System.Linq; -using Microsoft.Data.DataView; using Microsoft.ML; using Microsoft.ML.Data; using Microsoft.ML.RunTests; diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs index 9e84f775b8..8d6dd718f8 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs @@ -4,7 +4,6 @@ using System.IO; using System.Linq; -using Microsoft.Data.DataView; using Microsoft.ML.RunTests; using Microsoft.ML.Trainers; using Xunit; diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs index 4d2fe7c860..c04f8af1e1 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs @@ -7,7 +7,6 @@ using System.IO; using System.Linq; using System.Runtime.InteropServices; -using Microsoft.Data.DataView; using Microsoft.ML.Data; using Microsoft.ML.RunTests; using Microsoft.ML.TestFramework.Attributes; diff --git a/test/Microsoft.ML.Tests/Transformers/SelectColumnsTests.cs b/test/Microsoft.ML.Tests/Transformers/SelectColumnsTests.cs index a6c46e8dc5..2aceefeac5 100644 --- a/test/Microsoft.ML.Tests/Transformers/SelectColumnsTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/SelectColumnsTests.cs @@ -4,7 +4,6 @@ using System; using System.IO; -using Microsoft.Data.DataView; using Microsoft.ML.Data; using Microsoft.ML.Model; using Microsoft.ML.RunTests; diff --git a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs index fac58175f3..766eab0c8d 100644 --- a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs @@ -6,7 +6,6 @@ using System.Collections.Generic; using System.IO; using System.Linq; -using Microsoft.Data.DataView; using Microsoft.ML.Data; using Microsoft.ML.Model; using Microsoft.ML.RunTests; diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs index 1f75171284..fae0ec6276 100644 --- a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs +++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs @@ -4,7 +4,6 @@ using System.Collections.Generic; using System.IO; -using Microsoft.Data.DataView; using Microsoft.ML.Data; using Microsoft.ML.TestFramework.Attributes; using Microsoft.ML.Transforms.TimeSeries; From d1dfcd29ef920e220b203fea2126e30d2e50d85d Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Mon, 18 Mar 2019 11:03:13 -0700 Subject: [PATCH 18/18] Code review comments --- .../Model/ModelOperationsCatalog.cs | 2 +- .../UnitTests/TestEntryPoints.cs | 2 +- test/Microsoft.ML.Functional.Tests/ModelLoading.cs | 12 ++++++------ test/Microsoft.ML.Tests/ImagesTests.cs | 2 +- .../Scenarios/Api/CookbookSamples/CookbookSamples.cs | 2 +- .../Api/CookbookSamples/CookbookSamplesDynamicApi.cs | 4 ++-- .../TrainerEstimators/MatrixFactorizationTests.cs | 2 +- test/Microsoft.ML.Tests/Transformers/ConvertTests.cs | 2 +- .../Transformers/CopyColumnEstimatorTests.cs | 2 +- .../Transformers/SelectColumnsTests.cs | 4 ++-- .../Transformers/ValueMappingTests.cs | 2 +- .../TimeSeriesDirectApi.cs | 6 +++--- 12 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs index a7720185b0..559d93f202 100644 --- a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs @@ -180,7 +180,7 @@ public IDataLoader Load(Stream stream) /// A readable, seekable stream to load from. /// The data loader from the model stream. /// The transformer model from the model stream. - public ITransformer Load(Stream stream, out IDataLoader loader) + public ITransformer LoadWithDataLoader(Stream stream, out IDataLoader loader) { _env.CheckValue(stream, nameof(stream)); diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index c1183825ab..db57dd402c 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -5645,7 +5645,7 @@ public void LoadEntryPointModel() ITransformer loadedModel; using (var stream = File.OpenRead(modelPath)) { - loadedModel = ml.Model.Load(stream, out DataViewSchema inputSchema); + loadedModel = ml.Model.Load(stream, out var inputSchema); } } } diff --git a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs index 2adce45953..54d5b6a5c5 100644 --- a/test/Microsoft.ML.Functional.Tests/ModelLoading.cs +++ b/test/Microsoft.ML.Functional.Tests/ModelLoading.cs @@ -72,22 +72,22 @@ public void LoadModelAndExtractPredictor() IDataLoader loadedCompositeLoader; ITransformer loadedTransformerModel1; using (var fs = File.OpenRead(modelAndSchemaPath)) - loadedTransformerModel = _ml.Model.Load(fs, out DataViewSchema loadedSchema); + loadedTransformerModel = _ml.Model.Load(fs, out var loadedSchema); using (var fs = File.OpenRead(compositeLoaderModelPath)) { // This model can be loaded either as a composite data loader, // a transformer model + an input schema, or a transformer model + a data loader. - var t = _ml.Model.Load(fs, out IDataLoader l); - var t1 = _ml.Model.Load(fs, out DataViewSchema s); + var t = _ml.Model.LoadWithDataLoader(fs, out IDataLoader l); + var t1 = _ml.Model.Load(fs, out var s); loadedCompositeLoader = _ml.Model.Load(fs); } using (var fs = File.OpenRead(loaderAndTransformerModelPath)) { // This model can be loaded either as a composite data loader, // a transformer model + an input schema, or a transformer model + a data loader. - var t = _ml.Model.Load(fs, out DataViewSchema s); + var t = _ml.Model.Load(fs, out var s); var c = _ml.Model.Load(fs); - loadedTransformerModel1 = _ml.Model.Load(fs, out IDataLoader l); + loadedTransformerModel1 = _ml.Model.LoadWithDataLoader(fs, out IDataLoader l); } var gam = ((loadedTransformerModel as ISingleFeaturePredictionTransformer).Model @@ -320,7 +320,7 @@ private void Load(string filename, out ITransformer loadedWithSchema, out DataVi loadedWithSchema = _ml.Model.Load(fs, out loadedSchema); try { - loadedWithLoader = _ml.Model.Load(fs, out loadedLoaderWithTransformer); + loadedWithLoader = _ml.Model.LoadWithDataLoader(fs, out loadedLoaderWithTransformer); } catch (Exception) { diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index b38d574326..21e05f2747 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -83,7 +83,7 @@ public void TestEstimatorSaveLoad() ML.Model.Save(model, null, fs); ITransformer model2; using (var fs = file.OpenReadStream()) - model2 = ML.Model.Load(fs, out DataViewSchema schema); + model2 = ML.Model.Load(fs, out var schema); var transformerChain = model2 as TransformerChain; Assert.NotNull(transformerChain); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs index c0060b9e81..2d3f765b68 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs @@ -157,7 +157,7 @@ private void TrainRegression(string trainDataPath, string testDataPath, string m // When you load the model, it's a 'dynamic' transformer. ITransformer loadedModel; using (var stream = File.OpenRead(modelPath)) - loadedModel = mlContext.Model.Load(stream, out DataViewSchema schema); + loadedModel = mlContext.Model.Load(stream, out var schema); } [Fact] diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs index 1ebbb16729..ff1c3e83d1 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs @@ -131,7 +131,7 @@ private void TrainRegression(string trainDataPath, string testDataPath, string m // When you load the model, it's a 'dynamic' transformer. ITransformer loadedModel; using (var stream = File.OpenRead(modelPath)) - loadedModel = mlContext.Model.Load(stream, out DataViewSchema schema); + loadedModel = mlContext.Model.Load(stream, out var schema); } [Fact] @@ -535,7 +535,7 @@ private static void RunEndToEnd(MLContext mlContext, IDataView trainData, string // Now we can load the model. ITransformer loadedModel; using (var fs = File.OpenRead(modelPath)) - loadedModel = newContext.Model.Load(fs, out DataViewSchema schema); + loadedModel = newContext.Model.Load(fs, out var schema); } public static IDataView PrepareData(MLContext mlContext, IDataView data) diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs index c04f8af1e1..bba3945931 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs @@ -505,7 +505,7 @@ public void MatrixFactorizationBackCompat() { using (var fs = File.OpenRead(modelPath)) { - model = ML.Model.Load(fs, out DataViewSchema schema); + model = ML.Model.Load(fs, out var schema); // This model was saved without the input schema. Assert.Null(schema); } diff --git a/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs b/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs index c8e17036e3..e0167a241e 100644 --- a/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs @@ -261,7 +261,7 @@ public void TypeConvertKeyBackCompatTest() using (var ch = Env.Start("load")) { using (var fs = File.OpenRead(modelPath)) - modelOld = ML.Model.Load(fs, out DataViewSchema schema); + modelOld = ML.Model.Load(fs, out var schema); } var outDataOld = modelOld.Transform(dataView); diff --git a/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs index bcf85537bc..2be1292e9d 100644 --- a/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs @@ -95,7 +95,7 @@ void TestSavingAndLoading() { env.Model.Save(transformer, null, ms); ms.Position = 0; - var loadedTransformer = env.Model.Load(ms, out DataViewSchema schema); + var loadedTransformer = env.Model.Load(ms, out var schema); var result = loadedTransformer.Transform(dataView); ValidateCopyColumnTransformer(result); } diff --git a/test/Microsoft.ML.Tests/Transformers/SelectColumnsTests.cs b/test/Microsoft.ML.Tests/Transformers/SelectColumnsTests.cs index 2aceefeac5..288430d8ea 100644 --- a/test/Microsoft.ML.Tests/Transformers/SelectColumnsTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/SelectColumnsTests.cs @@ -182,7 +182,7 @@ void TestSelectSavingAndLoading() { ML.Model.Save(transformer, null, ms); ms.Position = 0; - var loadedTransformer = ML.Model.Load(ms, out DataViewSchema schema); + var loadedTransformer = ML.Model.Load(ms, out var schema); var result = loadedTransformer.Transform(dataView); Assert.Equal(2, result.Schema.Count); Assert.Equal("A", result.Schema[0].Name); @@ -202,7 +202,7 @@ void TestSelectSavingAndLoadingWithNoKeepHidden() { ML.Model.Save(transformer, null, ms); ms.Position = 0; - var loadedTransformer = ML.Model.Load(ms, out DataViewSchema schema); + var loadedTransformer = ML.Model.Load(ms, out var schema); var result = loadedTransformer.Transform(dataView); Assert.Equal(2, result.Schema.Count); Assert.Equal("A", result.Schema[0].Name); diff --git a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs index 766eab0c8d..9638a558de 100644 --- a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs @@ -602,7 +602,7 @@ void TestSavingAndLoading() { ML.Model.Save(transformer, null, ms); ms.Position = 0; - var loadedTransformer = ML.Model.Load(ms, out DataViewSchema schema); + var loadedTransformer = ML.Model.Load(ms, out var schema); var result = loadedTransformer.Transform(dataView); Assert.Equal(5, result.Schema.Count); Assert.True(result.Schema.TryGetColumnIndex("D", out int col)); diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs index fae0ec6276..05055b5613 100644 --- a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs +++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs @@ -179,7 +179,7 @@ public void ChangePointDetectionWithSeasonalityPredictionEngineNoColumn() //with "engine". ITransformer model2 = null; using (var file = File.OpenRead(modelPath)) - model2 = ml.Model.Load(file, out DataViewSchema schema); + model2 = ml.Model.Load(file, out var schema); //Raw score after state gets updated with two inputs. var engine2 = model2.CreateTimeSeriesPredictionFunction(ml); @@ -199,7 +199,7 @@ public void ChangePointDetectionWithSeasonalityPredictionEngineNoColumn() engine.CheckPoint(ml, modelPath + 1); ITransformer model3 = null; using (var file = File.OpenRead(modelPath + 1)) - model3 = ml.Model.Load(file, out DataViewSchema schema); + model3 = ml.Model.Load(file, out var schema); //Load the model with state updated with just one input, then pass in the second input //and raw score should match the raw score obtained by passing the two input in the first model. @@ -266,7 +266,7 @@ public void ChangePointDetectionWithSeasonalityPredictionEngine() // Load Model 1. ITransformer model2 = null; using (var file = File.OpenRead(modelPath)) - model2 = ml.Model.Load(file, out DataViewSchema schema); + model2 = ml.Model.Load(file, out var schema); //Predict and expect the same result after checkpointing(Prediction #2). engine = model2.CreateTimeSeriesPredictionFunction(ml);