-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Add API to save/load models with their input schema #2850
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,13 +45,34 @@ protected SubCatalogBase(ModelOperationsCatalog owner) | |
/// <param name="stream">A writeable, seekable stream to save to.</param> | ||
public void Save(ITransformer model, Stream stream) => model.SaveTo(Environment, stream); | ||
|
||
public void Save<TSource>(IDataLoader<TSource> model, Stream stream) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This I think this is not correct... we need to save both the loader (or failing that, input schema) as well as the transform model to a stream. Not two separate things. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I recall the mistake that was made with predictor models and transform models in entry-points, where we viewed these as single things saved to a stream, as opposed to having to capture the entire pipeline. This mistake was inherited by ML.NET v0.1, which led to all sorts of grief when people realized, yes, saving the loader is really freakin' important. So I think we're just making the same mistake again. A model isn't just a transform model, or a loader, or an input schema, it I think needs to capture all of them, and I think we need a change a bit more fundamental than merely adding a few methods here and there. We could have done that after v1.0. I think the problem is that the API here is just misleading. In reply to: 262638162 [](ancestors = 262638162) |
||
{ | ||
using (var rep = RepositoryWriter.CreateNew(stream)) | ||
{ | ||
ModelSaveContext.SaveModel(rep, model, "Model"); | ||
rep.Commit(); | ||
} | ||
} | ||
|
||
public void Save<TSource>(IDataLoader<TSource> loader, ITransformer model, Stream stream) => | ||
Save(new CompositeDataLoader<TSource, ITransformer>(loader, new TransformerChain<ITransformer>(model)), stream); | ||
|
||
/// <summary> | ||
/// Load the model from the stream. | ||
/// </summary> | ||
/// <param name="stream">A readable, seekable stream to load from.</param> | ||
/// <returns>The loaded model.</returns> | ||
public ITransformer Load(Stream stream) => TransformerChain.LoadFrom(Environment, stream); | ||
|
||
public IDataLoader<IMultiStreamSource> LoadAsCompositeDataLoader(Stream stream) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The subtle differences between usage here and usage of the above might be a bit hard of a pill to swallow. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this return |
||
{ | ||
using (var rep = RepositoryReader.Open(stream)) | ||
{ | ||
ModelLoadContext.LoadModel<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(Environment, out var model, rep, "Model"); | ||
return model; | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// The catalog of model explainability operations. | ||
/// </summary> | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,6 @@ | |
using Microsoft.ML.CommandLine; | ||
using Microsoft.ML.Data; | ||
using Microsoft.ML.EntryPoints; | ||
using Microsoft.ML.Internal.Internallearn; | ||
using Microsoft.ML.Internal.Utilities; | ||
using Microsoft.ML.Model; | ||
using Microsoft.ML.Model.OnnxConverter; | ||
|
@@ -58,19 +57,19 @@ | |
"Naive Calibration Executor", | ||
NaiveCalibrator.LoaderSignature)] | ||
|
||
[assembly: LoadableClass(typeof(ValueMapperCalibratedModelParameters<IPredictorProducing<float>, ICalibrator>), null, typeof(SignatureLoadModel), | ||
[assembly: LoadableClass(typeof(ValueMapperCalibratedModelParameters<object, ICalibrator>), null, typeof(SignatureLoadModel), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This is a bit stronger medicine than what I was suggesting! (I don't object per se, but it did entail a bit more work.) I suppose it's a bit more "honest" than the old way -- marker interfaces are more or less a promise you can't keep -- but they are still useful on internal code where you have complete control over the implementation (as we do as they are internal). So I think this should be fine, I just wonder, are we fully comfortable with this? #Resolved |
||
"Calibrated Predictor Executor", | ||
ValueMapperCalibratedModelParameters<IPredictorProducing<float>, ICalibrator>.LoaderSignature, "BulkCaliPredExec")] | ||
|
||
[assembly: LoadableClass(typeof(FeatureWeightsCalibratedModelParameters<IPredictorWithFeatureWeights<float>, ICalibrator>), null, typeof(SignatureLoadModel), | ||
[assembly: LoadableClass(typeof(FeatureWeightsCalibratedModelParameters<object, ICalibrator>), null, typeof(SignatureLoadModel), | ||
"Feature Weights Calibrated Predictor Executor", | ||
FeatureWeightsCalibratedModelParameters<IPredictorWithFeatureWeights<float>, ICalibrator>.LoaderSignature)] | ||
|
||
[assembly: LoadableClass(typeof(ParameterMixingCalibratedModelParameters<IPredictorWithFeatureWeights<float>, ICalibrator>), null, typeof(SignatureLoadModel), | ||
[assembly: LoadableClass(typeof(ParameterMixingCalibratedModelParameters<object, ICalibrator>), null, typeof(SignatureLoadModel), | ||
"Parameter Mixing Calibrated Predictor Executor", | ||
ParameterMixingCalibratedModelParameters<IPredictorWithFeatureWeights<float>, ICalibrator>.LoaderSignature)] | ||
|
||
[assembly: LoadableClass(typeof(SchemaBindableCalibratedModelParameters<IPredictorProducing<float>, ICalibrator>), null, typeof(SignatureLoadModel), | ||
[assembly: LoadableClass(typeof(SchemaBindableCalibratedModelParameters<object, ICalibrator>), null, typeof(SignatureLoadModel), | ||
"Schema Bindable Calibrated Predictor", SchemaBindableCalibratedModelParameters<IPredictorProducing<float>, ICalibrator>.LoaderSignature)] | ||
|
||
[assembly: LoadableClass(typeof(void), typeof(Calibrate), null, typeof(SignatureEntryPointModule), "Calibrate")] | ||
|
@@ -147,8 +146,8 @@ internal interface ISelfCalibratingPredictor | |
[BestFriend] | ||
internal interface IWeaklyTypedCalibratedModelParameters | ||
{ | ||
IPredictorProducing<float> WeeklyTypedSubModel { get; } | ||
ICalibrator WeeklyTypedCalibrator { get; } | ||
IPredictorProducing<float> WeaklyTypedSubModel { get; } | ||
ICalibrator WeaklyTypedCalibrator { get; } | ||
} | ||
|
||
/// <summary> | ||
|
@@ -186,8 +185,8 @@ public abstract class CalibratedModelParametersBase<TSubModel, TCalibrator> : | |
public TCalibrator Calibrator { get; } | ||
|
||
// Type-unsafed accessors of strongly-typed members. | ||
IPredictorProducing<float> IWeaklyTypedCalibratedModelParameters.WeeklyTypedSubModel => (IPredictorProducing<float>)SubModel; | ||
ICalibrator IWeaklyTypedCalibratedModelParameters.WeeklyTypedCalibrator => Calibrator; | ||
IPredictorProducing<float> IWeaklyTypedCalibratedModelParameters.WeaklyTypedSubModel => (IPredictorProducing<float>)SubModel; | ||
ICalibrator IWeaklyTypedCalibratedModelParameters.WeaklyTypedCalibrator => Calibrator; | ||
|
||
PredictionKind IPredictor.PredictionKind => ((IPredictorProducing<float>)SubModel).PredictionKind; | ||
|
||
|
@@ -198,6 +197,7 @@ private protected CalibratedModelParametersBase(IHostEnvironment env, string nam | |
Host = env.Register(name); | ||
Host.CheckValue(predictor, nameof(predictor)); | ||
Host.CheckValue(calibrator, nameof(calibrator)); | ||
Host.Assert(predictor is IPredictorProducing<float>); | ||
|
||
SubModel = predictor; | ||
Calibrator = calibrator; | ||
|
@@ -270,7 +270,7 @@ internal abstract class ValueMapperCalibratedModelParametersBase<TSubModel, TCal | |
CalibratedModelParametersBase<TSubModel, TCalibrator>, | ||
IValueMapperDist, IFeatureContributionMapper, ICalculateFeatureContribution, | ||
IDistCanSavePfa, IDistCanSaveOnnx | ||
where TSubModel : class, IPredictorProducing<float> | ||
where TSubModel : class | ||
where TCalibrator : class, ICalibrator | ||
{ | ||
private readonly IValueMapper _mapper; | ||
|
@@ -380,7 +380,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string | |
[BestFriend] | ||
internal sealed class ValueMapperCalibratedModelParameters<TSubModel, TCalibrator> : | ||
ValueMapperCalibratedModelParametersBase<TSubModel, TCalibrator>, ICanSaveModel | ||
where TSubModel : class, IPredictorProducing<float> | ||
where TSubModel : class | ||
where TCalibrator : class, ICalibrator | ||
{ | ||
internal ValueMapperCalibratedModelParameters(IHostEnvironment env, TSubModel predictor, TCalibrator calibrator) | ||
|
@@ -442,7 +442,7 @@ internal sealed class FeatureWeightsCalibratedModelParameters<TSubModel, TCalibr | |
ValueMapperCalibratedModelParametersBase<TSubModel, TCalibrator>, | ||
IPredictorWithFeatureWeights<float>, | ||
ICanSaveModel | ||
where TSubModel : class, IPredictorWithFeatureWeights<float> | ||
where TSubModel : class | ||
where TCalibrator : class, ICalibrator | ||
{ | ||
private readonly IPredictorWithFeatureWeights<float> _featureWeights; | ||
|
@@ -451,7 +451,8 @@ internal FeatureWeightsCalibratedModelParameters(IHostEnvironment env, TSubModel | |
TCalibrator calibrator) | ||
: base(env, RegistrationName, predictor, calibrator) | ||
{ | ||
_featureWeights = predictor; | ||
Host.Assert(predictor is IPredictorWithFeatureWeights<float>); | ||
_featureWeights = predictor as IPredictorWithFeatureWeights<float>; | ||
} | ||
|
||
internal const string LoaderSignature = "FeatWCaliPredExec"; | ||
|
@@ -506,7 +507,7 @@ internal sealed class ParameterMixingCalibratedModelParameters<TSubModel, TCalib | |
IParameterMixer<float>, | ||
IPredictorWithFeatureWeights<float>, | ||
ICanSaveModel | ||
where TSubModel : class, IPredictorWithFeatureWeights<float> | ||
where TSubModel : class | ||
where TCalibrator : class, ICalibrator | ||
{ | ||
private readonly IPredictorWithFeatureWeights<float> _featureWeights; | ||
|
@@ -516,7 +517,8 @@ internal ParameterMixingCalibratedModelParameters(IHostEnvironment env, TSubMode | |
{ | ||
Host.Check(predictor is IParameterMixer<float>, "Predictor does not implement " + nameof(IParameterMixer<float>)); | ||
Host.Check(calibrator is IParameterMixer, "Calibrator does not implement " + nameof(IParameterMixer)); | ||
_featureWeights = predictor; | ||
Host.Assert(predictor is IPredictorWithFeatureWeights<float>); | ||
_featureWeights = predictor as IPredictorWithFeatureWeights<float>; | ||
} | ||
|
||
internal const string LoaderSignature = "PMixCaliPredExec"; | ||
|
@@ -538,7 +540,7 @@ private ParameterMixingCalibratedModelParameters(IHostEnvironment env, ModelLoad | |
{ | ||
Host.Check(SubModel is IParameterMixer<float>, "Predictor does not implement " + nameof(IParameterMixer)); | ||
Host.Check(SubModel is IPredictorWithFeatureWeights<float>, "Predictor does not implement " + nameof(IPredictorWithFeatureWeights<float>)); | ||
_featureWeights = SubModel; | ||
_featureWeights = SubModel as IPredictorWithFeatureWeights<float>; | ||
} | ||
|
||
private static ParameterMixingCalibratedModelParameters<TSubModel, TCalibrator> Create(IHostEnvironment env, ModelLoadContext ctx) | ||
|
@@ -587,7 +589,7 @@ IParameterMixer<float> IParameterMixer<float>.CombineParameters(IList<IParameter | |
[BestFriend] | ||
internal sealed class SchemaBindableCalibratedModelParameters<TSubModel, TCalibrator> : CalibratedModelParametersBase<TSubModel, TCalibrator>, ISchemaBindableMapper, ICanSaveModel, | ||
IBindableCanSavePfa, IBindableCanSaveOnnx, IFeatureContributionMapper | ||
where TSubModel : class, IPredictorProducing<float> | ||
where TSubModel : class | ||
where TCalibrator : class, ICalibrator | ||
{ | ||
private sealed class Bound : ISchemaBoundRowMapper | ||
|
@@ -702,14 +704,14 @@ private static VersionInfo GetVersionInfo() | |
internal SchemaBindableCalibratedModelParameters(IHostEnvironment env, TSubModel predictor, TCalibrator calibrator) | ||
: base(env, LoaderSignature, predictor, calibrator) | ||
{ | ||
_bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubModel); | ||
_bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubModel as IPredictorProducing<float>); | ||
_featureContribution = SubModel as IFeatureContributionMapper; | ||
} | ||
|
||
private SchemaBindableCalibratedModelParameters(IHostEnvironment env, ModelLoadContext ctx) | ||
: base(env, LoaderSignature, GetPredictor(env, ctx), GetCalibrator(env, ctx)) | ||
{ | ||
_bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubModel); | ||
_bindable = ScoreUtils.GetSchemaBindableMapper(Host, SubModel as IPredictorProducing<float>); | ||
_featureContribution = SubModel as IFeatureContributionMapper; | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need xml comments for new public API. #Resolved