Skip to content

Add API to save/load models with their input schema #2850

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Data/IEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ internal bool TryFindColumn(string name, out Column column)
/// The 'data loader' takes a certain kind of input and turns it into an <see cref="IDataView"/>.
/// </summary>
/// <typeparam name="TSource">The type of input the loader takes.</typeparam>
public interface IDataLoader<in TSource>
public interface IDataLoader<in TSource> : ICanSaveModel
{
/// <summary>
/// Produce the data view from the specified input.
Expand Down
82 changes: 49 additions & 33 deletions src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

using System.IO;
using Microsoft.Data.DataView;
using Microsoft.ML.Model;
using Microsoft.ML;
using Microsoft.ML.Data;

[assembly: LoadableClass(CompositeDataLoader<IMultiStreamSource, ITransformer>.Summary, typeof(CompositeDataLoader<IMultiStreamSource, ITransformer>), null, typeof(SignatureLoadModel),
"Composite Loader", CompositeDataLoader<IMultiStreamSource, ITransformer>.LoaderSignature)]

namespace Microsoft.ML.Data
{
Expand All @@ -15,6 +19,10 @@ namespace Microsoft.ML.Data
public sealed class CompositeDataLoader<TSource, TLastTransformer> : IDataLoader<TSource>
where TLastTransformer : class, ITransformer
{
private const string LoaderDirectory = "Loader";
private const string LegacyLoaderDirectory = "Reader";
private const string TransformerDirectory = TransformerChain.LoaderSignature;

/// <summary>
/// The underlying data loader.
/// </summary>
Expand All @@ -33,6 +41,24 @@ public CompositeDataLoader(IDataLoader<TSource> loader, TransformerChain<TLastTr
Transformer = transformerChain ?? new TransformerChain<TLastTransformer>();
}

/// <summary>
/// Load-time constructor: fills <see cref="Loader"/> and <see cref="Transformer"/> from <paramref name="ctx"/>.
/// </summary>
/// <param name="host">The host passed through to the model-loading calls.</param>
/// <param name="ctx">The load context to read the sub-models from.</param>
private CompositeDataLoader(IHost host, ModelLoadContext ctx)
{
    // Look in the legacy directory first; only when that call reports false
    // is the loader read from the current directory.
    bool loadedFromLegacy = ctx.LoadModelOrNull<IDataLoader<TSource>, SignatureLoadModel>(host, out Loader, LegacyLoaderDirectory);
    if (!loadedFromLegacy)
    {
        ctx.LoadModel<IDataLoader<TSource>, SignatureLoadModel>(host, out Loader, LoaderDirectory);
    }

    // The transformer chain is always read from its own directory.
    ctx.LoadModel<TransformerChain<TLastTransformer>, SignatureLoadModel>(host, out Transformer, TransformerDirectory);
}

/// <summary>
/// Factory used when loading a saved composite loader: validates the arguments, checks the
/// context against <see cref="GetVersionInfo"/>, and then constructs the instance from <paramref name="ctx"/>.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="ctx">The load context; must not be null.</param>
/// <returns>The composite data loader built from <paramref name="ctx"/>.</returns>
private static CompositeDataLoader<TSource, TLastTransformer> Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));

    // Register a host under the loader signature; it is used for the remaining checks.
    IHost host = env.Register(LoaderSignature);
    host.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    return host.Apply(
        "Loading Model",
        channel => new CompositeDataLoader<TSource, TLastTransformer>(host, ctx));
}

/// <summary>
/// Produce the data view from the specified input.
/// Note that <see cref="IDataView"/>'s are lazy, so no actual loading happens here, just schema validation.
Expand Down Expand Up @@ -62,6 +88,16 @@ public CompositeDataLoader<TSource, TNewLast> AppendTransformer<TNewLast>(TNewLa
return new CompositeDataLoader<TSource, TNewLast>(Loader, Transformer.Append(transformer));
}

/// <summary>
/// Explicit <see cref="ICanSaveModel"/> implementation: writes this composite loader into <paramref name="ctx"/>.
/// </summary>
/// <param name="ctx">The save context to write to; validated with <c>Contracts.CheckValue</c> before use.</param>
void ICanSaveModel.Save(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
// Stamp the context with this class's version info before writing the sub-models.
ctx.SetVersionInfo(GetVersionInfo());

// The loader and the transformer chain are saved as two sub-models,
// under LoaderDirectory and TransformerDirectory respectively.
ctx.SaveModel(Loader, LoaderDirectory);
ctx.SaveModel(Transformer, TransformerDirectory);
}

/// <summary>
/// Save the contents to a stream, as a "model file".
/// </summary>
Expand All @@ -76,47 +112,27 @@ public void SaveTo(IHostEnvironment env, Stream outputStream)
using (var rep = RepositoryWriter.CreateNew(outputStream, ch))
{
ch.Trace("Saving data loader");
ModelSaveContext.SaveModel(rep, Loader, "Reader");
ModelSaveContext.SaveModel(rep, Loader, LoaderDirectory);

ch.Trace("Saving transformer chain");
ModelSaveContext.SaveModel(rep, Transformer, TransformerChain.LoaderSignature);
ModelSaveContext.SaveModel(rep, Transformer, TransformerDirectory);
rep.Commit();
}
}
}
}

/// <summary>
/// Utility class to facilitate loading from a stream.
/// </summary>
[BestFriend]
internal static class CompositeDataLoader
{
/// <summary>
/// Save the contents to a stream, as a "model file".
/// </summary>
public static void SaveTo<TSource>(this IDataLoader<TSource> loader, IHostEnvironment env, Stream outputStream)
=> new CompositeDataLoader<TSource, ITransformer>(loader).SaveTo(env, outputStream);
internal const string Summary = "A loader that encapsulates a loader and a transformer chain.";

/// <summary>
/// Load the pipeline from stream.
/// </summary>
public static CompositeDataLoader<IMultiStreamSource, ITransformer> LoadFrom(IHostEnvironment env, Stream stream)
internal const string LoaderSignature = "CompositeLoader";
private static VersionInfo GetVersionInfo()
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(stream, nameof(stream));

env.Check(stream.CanRead && stream.CanSeek, "Need a readable and seekable stream to load");
using (var rep = RepositoryReader.Open(stream, env))
using (var ch = env.Start("Loading pipeline"))
{
ch.Trace("Loading data loader");
ModelLoadContext.LoadModel<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(env, out var loader, rep, "Reader");

ch.Trace("Loader transformer chain");
ModelLoadContext.LoadModel<TransformerChain<ITransformer>, SignatureLoadModel>(env, out var transformerChain, rep, TransformerChain.LoaderSignature);
return new CompositeDataLoader<IMultiStreamSource, ITransformer>(loader, transformerChain);
}
return new VersionInfo(
modelSignature: "CMPSTLDR",
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(CompositeDataLoader<,>).Assembly.FullName);
}
}
}
6 changes: 4 additions & 2 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,22 @@
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;

[assembly: LoadableClass(TextLoader.Summary, typeof(ILegacyDataLoader), typeof(TextLoader), typeof(TextLoader.Options), typeof(SignatureDataLoader),
"Text Loader", "TextLoader", "Text", DocName = "loader/TextLoader.md")]

[assembly: LoadableClass(TextLoader.Summary, typeof(ILegacyDataLoader), typeof(TextLoader), null, typeof(SignatureLoadDataLoader),
"Text Loader", TextLoader.LoaderSignature)]

[assembly: LoadableClass(TextLoader.Summary, typeof(TextLoader), null, typeof(SignatureLoadModel),
"Text Loader", TextLoader.LoaderSignature)]

namespace Microsoft.ML.Data
{
/// <summary>
/// Loads a text file into an IDataView. Supports basic mapping from input columns to <see cref="IDataView"/> columns.
/// </summary>
public sealed partial class TextLoader : IDataLoader<IMultiStreamSource>, ICanSaveModel
public sealed partial class TextLoader : IDataLoader<IMultiStreamSource>
{
/// <summary>
/// Describes how an input column should be mapped to an <see cref="IDataView"/> column.
Expand Down
4 changes: 3 additions & 1 deletion src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,9 @@ public static TransformerChain<ITransformer> LoadFrom(IHostEnvironment env, Stre
{
try
{
ModelLoadContext.LoadModel<TransformerChain<ITransformer>, SignatureLoadModel>(env, out var transformerChain, rep, LoaderSignature);
ModelLoadContext.LoadModelOrNull<TransformerChain<ITransformer>, SignatureLoadModel>(env, out var transformerChain, rep, LoaderSignature);
if (transformerChain == null)
ModelLoadContext.LoadModel<TransformerChain<ITransformer>, SignatureLoadModel>(env, out transformerChain, rep, $@"Model\{LoaderSignature}");
return transformerChain;
}
catch (FormatException ex)
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ internal override string[] GetLabelInfo(IHostEnvironment env, out DataViewType l
var calibrated = predictor as IWeaklyTypedCalibratedModelParameters;
while (calibrated != null)
{
predictor = calibrated.WeeklyTypedSubModel;
predictor = calibrated.WeaklyTypedSubModel;
calibrated = predictor as IWeaklyTypedCalibratedModelParameters;
}
var canGetTrainingLabelNames = predictor as ICanGetTrainingLabelNames;
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ internal static IDataView GetSummaryAndStats(IHostEnvironment env, IPredictor pr
var calibrated = predictor as IWeaklyTypedCalibratedModelParameters;
while (calibrated != null)
{
predictor = calibrated.WeeklyTypedSubModel;
predictor = calibrated.WeaklyTypedSubModel;
calibrated = predictor as IWeaklyTypedCalibratedModelParameters;
}

Expand Down
21 changes: 21 additions & 0 deletions src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,34 @@ protected SubCatalogBase(ModelOperationsCatalog owner)
/// <param name="stream">A writeable, seekable stream to save to.</param>
public void Save(ITransformer model, Stream stream) => model.SaveTo(Environment, stream);

public void Save<TSource>(IDataLoader<TSource> model, Stream stream)
Copy link
Member

@eerhardt eerhardt Mar 5, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need XML comments for new public API. #Resolved

Copy link
Contributor

@TomFinley TomFinley Mar 5, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is not correct... we need to save both the loader (or failing that, input schema) as well as the transform model to a stream. Not two separate things. #Resolved

Copy link
Contributor

@TomFinley TomFinley Mar 5, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recall the mistake that was made with predictor models and transform models in entry-points, where we viewed these as single things saved to a stream, as opposed to having to capture the entire pipeline. This mistake was inherited by ML.NET v0.1, which led to all sorts of grief when people realized, yes, saving the loader is really freakin' important. So I think we're just making the same mistake again. A model isn't just a transform model, or a loader, or an input schema; I think it needs to capture all of them, and I think we need a change a bit more fundamental than merely adding a few methods here and there. We could have done that after v1.0. I think the problem is that the API here is just misleading.


In reply to: 262638162 [](ancestors = 262638162)

{
using (var rep = RepositoryWriter.CreateNew(stream))
{
ModelSaveContext.SaveModel(rep, model, "Model");
rep.Commit();
}
}

/// <summary>
/// Save a data loader together with a transformer to the stream as one composite model.
/// </summary>
/// <typeparam name="TSource">The type of input the loader takes.</typeparam>
/// <param name="loader">The data loader to save.</param>
/// <param name="model">The transformer to save; it is wrapped in a new transformer chain.</param>
/// <param name="stream">The stream to save to.</param>
public void Save<TSource>(IDataLoader<TSource> loader, ITransformer model, Stream stream)
{
    // Wrap the single transformer in a chain, pair it with the loader,
    // and delegate to the loader-only Save overload.
    var chain = new TransformerChain<ITransformer>(model);
    var composite = new CompositeDataLoader<TSource, ITransformer>(loader, chain);
    Save(composite, stream);
}

/// <summary>
/// Load the model from the stream.
/// </summary>
/// <param name="stream">A readable, seekable stream to load from.</param>
/// <returns>The loaded model.</returns>
public ITransformer Load(Stream stream) => TransformerChain.LoadFrom(Environment, stream);

public IDataLoader<IMultiStreamSource> LoadAsCompositeDataLoader(Stream stream)
Copy link
Contributor

@TomFinley TomFinley Mar 5, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IDataLoader [](start = 15, length = 11)

The subtle differences between usage here and usage of the above might be a bit of a hard pill to swallow. #Resolved

Copy link
Member

@eerhardt eerhardt Mar 5, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this return CompositeDataLoader instead of IDataLoader? The name is: LoadAsCompositeDataLoader, which makes me think it should return a CompositeDataLoader #Resolved

{
using (var rep = RepositoryReader.Open(stream))
{
ModelLoadContext.LoadModel<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(Environment, out var model, rep, "Model");
return model;
}
}

/// <summary>
/// The catalog of model explainability operations.
/// </summary>
Expand Down
40 changes: 21 additions & 19 deletions src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Model.OnnxConverter;
Expand Down Expand Up @@ -58,19 +57,19 @@
"Naive Calibration Executor",
NaiveCalibrator.LoaderSignature)]

[assembly: LoadableClass(typeof(ValueMapperCalibratedModelParameters<IPredictorProducing<float>, ICalibrator>), null, typeof(SignatureLoadModel),
[assembly: LoadableClass(typeof(ValueMapperCalibratedModelParameters<object, ICalibrator>), null, typeof(SignatureLoadModel),
Copy link
Contributor

@TomFinley TomFinley Mar 5, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

object [](start = 69, length = 6)

This is a bit stronger medicine than what I was suggesting! (I don't object per se, but it did entail a bit more work.) I suppose it's a bit more "honest" than the old way -- marker interfaces are more or less a promise you can't keep -- but they are still useful on internal code where you have complete control over the implementation (as we do as they are internal).

So I think this should be fine, I just wonder, are we fully comfortable with this? #Resolved

"Calibrated Predictor Executor",
ValueMapperCalibratedModelParameters<IPredictorProducing<float>, ICalibrator>.LoaderSignature, "BulkCaliPredExec")]

[assembly: LoadableClass(typeof(FeatureWeightsCalibratedModelParameters<IPredictorWithFeatureWeights<float>, ICalibrator>), null, typeof(SignatureLoadModel),
[assembly: LoadableClass(typeof(FeatureWeightsCalibratedModelParameters<object, ICalibrator>), null, typeof(SignatureLoadModel),
"Feature Weights Calibrated Predictor Executor",
FeatureWeightsCalibratedModelParameters<IPredictorWithFeatureWeights<float>, ICalibrator>.LoaderSignature)]

[assembly: LoadableClass(typeof(ParameterMixingCalibratedModelParameters<IPredictorWithFeatureWeights<float>, ICalibrator>), null, typeof(SignatureLoadModel),
[assembly: LoadableClass(typeof(ParameterMixingCalibratedModelParameters<object, ICalibrator>), null, typeof(SignatureLoadModel),
"Parameter Mixing Calibrated Predictor Executor",
ParameterMixingCalibratedModelParameters<IPredictorWithFeatureWeights<float>, ICalibrator>.LoaderSignature)]

[assembly: LoadableClass(typeof(SchemaBindableCalibratedModelParameters<IPredictorProducing<float>, ICalibrator>), null, typeof(SignatureLoadModel),
[assembly: LoadableClass(typeof(SchemaBindableCalibratedModelParameters<object, ICalibrator>), null, typeof(SignatureLoadModel),
"Schema Bindable Calibrated Predictor", SchemaBindableCalibratedModelParameters<IPredictorProducing<float>, ICalibrator>.LoaderSignature)]

[assembly: LoadableClass(typeof(void), typeof(Calibrate), null, typeof(SignatureEntryPointModule), "Calibrate")]
Expand Down Expand Up @@ -147,8 +146,8 @@ internal interface ISelfCalibratingPredictor
[BestFriend]
internal interface IWeaklyTypedCalibratedModelParameters
{
IPredictorProducing<float> WeeklyTypedSubModel { get; }
ICalibrator WeeklyTypedCalibrator { get; }
IPredictorProducing<float> WeaklyTypedSubModel { get; }
ICalibrator WeaklyTypedCalibrator { get; }
}

/// <summary>
Expand Down Expand Up @@ -186,8 +185,8 @@ public abstract class CalibratedModelParametersBase<TSubModel, TCalibrator> :
public TCalibrator Calibrator { get; }

// Type-unsafe accessors of strongly-typed members.
IPredictorProducing<float> IWeaklyTypedCalibratedModelParameters.WeeklyTypedSubModel => (IPredictorProducing<float>)SubModel;
ICalibrator IWeaklyTypedCalibratedModelParameters.WeeklyTypedCalibrator => Calibrator;
IPredictorProducing<float> IWeaklyTypedCalibratedModelParameters.WeaklyTypedSubModel => (IPredictorProducing<float>)SubModel;
ICalibrator IWeaklyTypedCalibratedModelParameters.WeaklyTypedCalibrator => Calibrator;

PredictionKind IPredictor.PredictionKind => ((IPredictorProducing<float>)SubModel).PredictionKind;

Expand All @@ -198,6 +197,7 @@ private protected CalibratedModelParametersBase(IHostEnvironment env, string nam
Host = env.Register(name);
Host.CheckValue(predictor, nameof(predictor));
Host.CheckValue(calibrator, nameof(calibrator));
Host.Assert(predictor is IPredictorProducing<float>);

SubModel = predictor;
Calibrator = calibrator;
Expand Down Expand Up @@ -270,7 +270,7 @@ internal abstract class ValueMapperCalibratedModelParametersBase<TSubModel, TCal
CalibratedModelParametersBase<TSubModel, TCalibrator>,
IValueMapperDist, IFeatureContributionMapper, ICalculateFeatureContribution,
IDistCanSavePfa, IDistCanSaveOnnx
where TSubModel : class, IPredictorProducing<float>
where TSubModel : class
where TCalibrator : class, ICalibrator
{
private readonly IValueMapper _mapper;
Expand Down Expand Up @@ -380,7 +380,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
[BestFriend]
internal sealed class ValueMapperCalibratedModelParameters<TSubModel, TCalibrator> :
ValueMapperCalibratedModelParametersBase<TSubModel, TCalibrator>, ICanSaveModel
where TSubModel : class, IPredictorProducing<float>
where TSubModel : class
where TCalibrator : class, ICalibrator
{
internal ValueMapperCalibratedModelParameters(IHostEnvironment env, TSubModel predictor, TCalibrator calibrator)
Expand Down Expand Up @@ -442,7 +442,7 @@ internal sealed class FeatureWeightsCalibratedModelParameters<TSubModel, TCalibr
ValueMapperCalibratedModelParametersBase<TSubModel, TCalibrator>,
IPredictorWithFeatureWeights<float>,
ICanSaveModel
where TSubModel : class, IPredictorWithFeatureWeights<float>
where TSubModel : class
where TCalibrator : class, ICalibrator
{
private readonly IPredictorWithFeatureWeights<float> _featureWeights;
Expand All @@ -451,7 +451,8 @@ internal FeatureWeightsCalibratedModelParameters(IHostEnvironment env, TSubModel
TCalibrator calibrator)
: base(env, RegistrationName, predictor, calibrator)
{
_featureWeights = predictor;
Host.Assert(predictor is IPredictorWithFeatureWeights<float>);
_featureWeights = predictor as IPredictorWithFeatureWeights<float>;
}

internal const string LoaderSignature = "FeatWCaliPredExec";
Expand Down Expand Up @@ -506,7 +507,7 @@ internal sealed class ParameterMixingCalibratedModelParameters<TSubModel, TCalib
IParameterMixer<float>,
IPredictorWithFeatureWeights<float>,
ICanSaveModel
where TSubModel : class, IPredictorWithFeatureWeights<float>
where TSubModel : class
where TCalibrator : class, ICalibrator
{
private readonly IPredictorWithFeatureWeights<float> _featureWeights;
Expand All @@ -516,7 +517,8 @@ internal ParameterMixingCalibratedModelParameters(IHostEnvironment env, TSubMode
{
Host.Check(predictor is IParameterMixer<float>, "Predictor does not implement " + nameof(IParameterMixer<float>));
Host.Check(calibrator is IParameterMixer, "Calibrator does not implement " + nameof(IParameterMixer));
_featureWeights = predictor;
Host.Assert(predictor is IPredictorWithFeatureWeights<float>);
_featureWeights = predictor as IPredictorWithFeatureWeights<float>;
}

internal const string LoaderSignature = "PMixCaliPredExec";
Expand All @@ -538,7 +540,7 @@ private ParameterMixingCalibratedModelParameters(IHostEnvironment env, ModelLoad
{
Host.Check(SubModel is IParameterMixer<float>, "Predictor does not implement " + nameof(IParameterMixer));
Host.Check(SubModel is IPredictorWithFeatureWeights<float>, "Predictor does not implement " + nameof(IPredictorWithFeatureWeights<float>));
_featureWeights = SubModel;
_featureWeights = SubModel as IPredictorWithFeatureWeights<float>;
}

private static ParameterMixingCalibratedModelParameters<TSubModel, TCalibrator> Create(IHostEnvironment env, ModelLoadContext ctx)
Expand Down Expand Up @@ -587,7 +589,7 @@ IParameterMixer<float> IParameterMixer<float>.CombineParameters(IList<IParameter
[BestFriend]
internal sealed class SchemaBindableCalibratedModelParameters<TSubModel, TCalibrator> : CalibratedModelParametersBase<TSubModel, TCalibrator>, ISchemaBindableMapper, ICanSaveModel,
IBindableCanSavePfa, IBindableCanSaveOnnx, IFeatureContributionMapper
where TSubModel : class, IPredictorProducing<float>
where TSubModel : class
where TCalibrator : class, ICalibrator
{
private sealed class Bound : ISchemaBoundRowMapper
Expand Down Expand Up @@ -702,14 +704,14 @@ private static VersionInfo GetVersionInfo()
internal SchemaBindableCalibratedModelParameters(IHostEnvironment env, TSubModel predictor, TCalibrator calibrator)
: base(env, LoaderSignature, predictor, calibrator)
{
_bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubModel);
_bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubModel as IPredictorProducing<float>);
_featureContribution = SubModel as IFeatureContributionMapper;
}

private SchemaBindableCalibratedModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, LoaderSignature, GetPredictor(env, ctx), GetCalibrator(env, ctx))
{
_bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubModel);
_bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubModel as IPredictorProducing<float>);
_featureContribution = SubModel as IFeatureContributionMapper;
}

Expand Down
Loading