-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Remove model saving/loading inconsistencies #3044
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0bbc58c
f069ceb
4eecb8b
6f9cc87
7e72431
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,63 +29,60 @@ internal ModelOperationsCatalog(IHostEnvironment env) | |
} | ||
|
||
/// <summary> | ||
/// Save the model to the stream. | ||
/// Save a transformer model and the loader used to create its input data to the stream. | ||
/// </summary> | ||
/// <param name="model">The trained model to be saved.</param> | ||
/// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand | ||
/// for an empty transformer chain. Upon loading with <see cref="LoadWithDataLoader(Stream, out IDataLoader{IMultiStreamSource})"/> | ||
/// the returned value will be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param> | ||
/// <param name="loader">The loader that was used to create data to train the model.</param> | ||
/// <param name="stream">A writeable, seekable stream to save to.</param> | ||
public void Save<TSource>(IDataLoader<TSource> model, Stream stream) | ||
public void Save<TSource>(ITransformer model, IDataLoader<TSource> loader, Stream stream) | ||
{ | ||
_env.CheckValue(model, nameof(model)); | ||
_env.CheckValue(loader, nameof(loader)); | ||
_env.CheckValueOrNull(model); | ||
_env.CheckValue(stream, nameof(stream)); | ||
|
||
// For the sake of consistency of this API specifically, when called upon we save any transformer | ||
// in a single element transformer chain. | ||
var chainedModel = model == null ? null : new TransformerChain<ITransformer>(model); | ||
TomFinley marked this conversation as resolved.
Show resolved
Hide resolved
|
||
var compositeLoader = new CompositeDataLoader<TSource, ITransformer>(loader, chainedModel); | ||
|
||
using (var rep = RepositoryWriter.CreateNew(stream)) | ||
{ | ||
ModelSaveContext.SaveModel(rep, model, null); | ||
ModelSaveContext.SaveModel(rep, compositeLoader, null); | ||
rep.Commit(); | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Save the model to the file. | ||
/// </summary> | ||
/// <param name="model">The trained model to be saved.</param> | ||
/// <param name="filePath">Path where model should be saved.</param> | ||
public void Save<TSource>(IDataLoader<TSource> model, string filePath) | ||
{ | ||
using (var stream = File.Create(filePath)) | ||
Save(model, stream); | ||
} | ||
|
||
/// <summary> | ||
/// Save a transformer model and the loader used to create its input data to the stream. | ||
/// </summary> | ||
/// <param name="loader">The loader that was used to create data to train the model</param> | ||
/// <param name="model">The trained model to be saved</param> | ||
/// <param name="stream">A writeable, seekable stream to save to.</param> | ||
public void Save<TSource>(IDataLoader<TSource> loader, ITransformer model, Stream stream) => | ||
Save(new CompositeDataLoader<TSource, ITransformer>(loader, new TransformerChain<ITransformer>(model)), stream); | ||
|
||
/// <summary> | ||
/// Save a transformer model and the loader used to create its input data to the file. | ||
/// </summary> | ||
/// <param name="loader">The loader that was used to create data to train the model</param> | ||
/// <param name="model">The trained model to be saved</param> | ||
/// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand | ||
/// for an empty transformer chain. Upon loading with <see cref="LoadWithDataLoader(Stream, out IDataLoader{IMultiStreamSource})"/> | ||
/// the returned value will be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param> | ||
/// <param name="loader">The loader that was used to create data to train the model.</param> | ||
/// <param name="filePath">Path where model should be saved.</param> | ||
public void Save<TSource>(IDataLoader<TSource> loader, ITransformer model, string filePath) | ||
public void Save<TSource>(ITransformer model, IDataLoader<TSource> loader, string filePath) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm also curious about whether we agree that what is effectively the descriptor of input to the model should come after the model itself. The API had been inconsistent, in that in one save method we had loader/model (so input to model came first), whereas in another we had model/schema (so input to model came second). In all cases I feel like the load methods should always just return I kept it as a separate commit in case we wanted to go the other way, maybe... so in that case, the schema. |
||
{ | ||
_env.CheckValueOrNull(model); | ||
_env.CheckValue(loader, nameof(loader)); | ||
TomFinley marked this conversation as resolved.
Show resolved
Hide resolved
eerhardt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
_env.CheckNonEmpty(filePath, nameof(filePath)); | ||
|
||
using (var stream = File.Create(filePath)) | ||
Save(loader, model, stream); | ||
Save(model, loader, stream); | ||
} | ||
|
||
/// <summary> | ||
/// Save a transformer model and the schema of the data that was used to train it to the stream. | ||
/// </summary> | ||
/// <param name="model">The trained model to be saved.</param> | ||
/// <param name="inputSchema">The schema of the input to the transformer. This can be null.</param> | ||
/// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand | ||
/// for an empty transformer chain. Upon loading with <see cref="Load(Stream, out DataViewSchema)"/> the returned value will | ||
/// be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param> | ||
/// <param name="inputSchema">The schema of the input to the transformer. This can be <see langword="null"/>.</param> | ||
/// <param name="stream">A writeable, seekable stream to save to.</param> | ||
public void Save(ITransformer model, DataViewSchema inputSchema, Stream stream) | ||
{ | ||
_env.CheckValue(model, nameof(model)); | ||
_env.CheckValueOrNull(model); | ||
_env.CheckValueOrNull(inputSchema); | ||
_env.CheckValue(stream, nameof(stream)); | ||
|
||
|
@@ -100,11 +97,17 @@ public void Save(ITransformer model, DataViewSchema inputSchema, Stream stream) | |
/// <summary> | ||
/// Save a transformer model and the schema of the data that was used to train it to the file. | ||
/// </summary> | ||
/// <param name="model">The trained model to be saved.</param> | ||
/// <param name="inputSchema">The schema of the input to the transformer. This can be null.</param> | ||
/// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand | ||
/// for an empty transformer chain. Upon loading with <see cref="Load(Stream, out DataViewSchema)"/> the returned value will | ||
/// be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param> | ||
/// <param name="inputSchema">The schema of the input to the transformer. This can be <see langword="null"/>.</param> | ||
/// <param name="filePath">Path where model should be saved.</param> | ||
public void Save(ITransformer model, DataViewSchema inputSchema, string filePath) | ||
{ | ||
_env.CheckValueOrNull(model); | ||
_env.CheckValueOrNull(inputSchema); | ||
_env.CheckNonEmpty(filePath, nameof(filePath)); | ||
|
||
using (var stream = File.Create(filePath)) | ||
Save(model, inputSchema, stream); | ||
} | ||
|
@@ -126,11 +129,11 @@ private void SaveInputSchema(DataViewSchema inputSchema, RepositoryWriter rep) | |
} | ||
|
||
/// <summary> | ||
/// Load the model and its input schema from the stream. | ||
/// Load the model and its input schema from a stream. | ||
/// </summary> | ||
/// <param name="stream">A readable, seekable stream to load from.</param> | ||
/// <param name="inputSchema">Will contain the input schema for the model. If the model was saved using older APIs | ||
/// it may not contain an input schema, in this case <paramref name="inputSchema"/> will be null.</param> | ||
/// <param name="inputSchema">Will contain the input schema for the model. If the model was saved without | ||
/// any description of the input, there will be no input schema. In this case this can be <see langword="null"/>.</param> | ||
/// <returns>The loaded model.</returns> | ||
public ITransformer Load(Stream stream, out DataViewSchema inputSchema) | ||
{ | ||
|
@@ -171,57 +174,100 @@ public ITransformer Load(Stream stream, out DataViewSchema inputSchema) | |
throw _env.Except(ex, "Could not load legacy format model"); | ||
} | ||
} | ||
if (dataLoader is CompositeDataLoader<IMultiStreamSource, ITransformer> composite) | ||
{ | ||
inputSchema = composite.Loader.GetOutputSchema(); | ||
return composite.Transformer; | ||
} | ||
var transformer = DecomposeLoader(ref dataLoader); | ||
inputSchema = dataLoader.GetOutputSchema(); | ||
return new TransformerChain<ITransformer>(); | ||
return transformer; | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Load the model and its input schema from a file. | ||
/// </summary> | ||
/// <param name="filePath">Path to a file where the model should be read from.</param> | ||
/// <param name="inputSchema">Will contain the input schema for the model. If the model was saved without | ||
/// any description of the input, there will be no input schema. In this case this can be <see langword="null"/>.</param> | ||
/// <returns>The loaded model.</returns> | ||
public ITransformer Load(string filePath, out DataViewSchema inputSchema) | ||
{ | ||
_env.CheckNonEmpty(filePath, nameof(filePath)); | ||
|
||
using (var stream = File.OpenRead(filePath)) | ||
return Load(stream, out inputSchema); | ||
} | ||
|
||
/// <summary> | ||
/// Given a loader, try to "decompose" it into a source loader, and its transform if any. | ||
/// If necessary an empty chain will be created to stand in for the trivial transformation; it | ||
/// should never return <see langword="null"/>. | ||
/// </summary> | ||
private ITransformer DecomposeLoader(ref IDataLoader<IMultiStreamSource> loader) | ||
{ | ||
_env.AssertValue(loader); | ||
|
||
if (loader is CompositeDataLoader<IMultiStreamSource, ITransformer> composite) | ||
{ | ||
loader = composite.Loader; | ||
var chain = composite.Transformer; | ||
// The save method corresponding to this load method encapsulates the input ITransformer | ||
// into a single-element transformer chain. If the loaded chain has exactly one element, we guess | ||
// that it was saved that way, and so return that single element rather than the chain. | ||
var accessor = (ITransformerChainAccessor)chain; | ||
if (accessor.Transformers.Length == 1) | ||
return accessor.Transformers[0]; | ||
// If it is some other length than 1 due to, say, some legacy model saving, just return that | ||
// chain. Using the above API this is not possible, since the chain saved will always be of length | ||
// one, but older APIs behaved differently so we should retain flexibility with those schemes. | ||
// (Those schemes are BTW by no means incorrect, they just aren't what the API in this particular | ||
// class will specifically do.) | ||
return chain; | ||
} | ||
// Maybe we have no transformer stored. Rather than return null, we prefer to return the | ||
// empty "trivial" transformer chain. | ||
return new TransformerChain<ITransformer>(); | ||
} | ||
|
||
/// <summary> | ||
/// Load the model and its input schema from the stream. | ||
/// Load a transformer model and a data loader model from a stream. | ||
/// </summary> | ||
/// <param name="stream">A readable, seekable stream to load from.</param> | ||
/// <returns>A model of type <see cref="CompositeDataLoader{IMultiStreamSource, ITransformer}"/> containing the loader | ||
/// and the transformer chain.</returns> | ||
public IDataLoader<IMultiStreamSource> Load(Stream stream) | ||
/// <param name="loader">The data loader from the model stream. Note that if there is no data loader, | ||
/// this method will throw an exception. The scenario where no loader is stored in the stream should | ||
/// be handled instead using the <see cref="Load(Stream, out DataViewSchema)"/> method.</param> | ||
/// <returns>The transformer model from the model stream.</returns> | ||
public ITransformer LoadWithDataLoader(Stream stream, out IDataLoader<IMultiStreamSource> loader) | ||
{ | ||
_env.CheckValue(stream, nameof(stream)); | ||
|
||
using (var rep = RepositoryReader.Open(stream)) | ||
{ | ||
try | ||
{ | ||
ModelLoadContext.LoadModel<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(_env, out var model, rep, null); | ||
return model; | ||
ModelLoadContext.LoadModel<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(_env, out loader, rep, null); | ||
return DecomposeLoader(ref loader); | ||
} | ||
catch (Exception ex) | ||
{ | ||
throw _env.Except(ex, "Model does not contain an IDataLoader"); | ||
throw _env.Except(ex, "Model does not contain an " + nameof(IDataLoader<IMultiStreamSource>) + | ||
". Perhaps this was saved with an " + nameof(DataViewSchema) + ", or even no information on its input at all. " + | ||
"Consider using the " + nameof(Load) + " method instead."); | ||
} | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Load a transformer model and a data loader model from the stream. | ||
/// Load a transformer model and a data loader model from a file. | ||
/// </summary> | ||
/// <param name="stream">A readable, seekable stream to load from.</param> | ||
/// <param name="loader">The data loader from the model stream.</param> | ||
/// <returns>The transformer model from the model stream.</returns> | ||
public ITransformer LoadWithDataLoader(Stream stream, out IDataLoader<IMultiStreamSource> loader) | ||
/// <param name="filePath">Path to a file where the model should be read from.</param> | ||
/// <param name="loader">The data loader from the model stream. Note that if there is no data loader, | ||
/// this method will throw an exception. The scenario where no loader is stored in the stream should | ||
/// be handled instead using the <see cref="Load(Stream, out DataViewSchema)"/> method.</param> | ||
/// <returns>The transformer model from the model file.</returns> | ||
public ITransformer LoadWithDataLoader(string filePath, out IDataLoader<IMultiStreamSource> loader) | ||
{ | ||
_env.CheckValue(stream, nameof(stream)); | ||
_env.CheckNonEmpty(filePath, nameof(filePath)); | ||
|
||
loader = Load(stream); | ||
if (loader is CompositeDataLoader<IMultiStreamSource, ITransformer> composite) | ||
{ | ||
loader = composite.Loader; | ||
return composite.Transformer; | ||
} | ||
return new TransformerChain<ITransformer>(); | ||
using (var stream = File.OpenRead(filePath)) | ||
return LoadWithDataLoader(stream, out loader); | ||
} | ||
|
||
/// <summary> | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Major things on the top of my mind, the save/load methods are still inconsistent w.r.t. where they put the file or stream argument... here we see, it is the last argument, whereas during loading it is the first argument.
One possibility is we leave it as is, since in this case it is where the model is going (so maybe last is most clear), whereas during loading it is where the model is coming from.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I kind of like this possibility.
If we're going with the "intuitive" ordering, then maybe `loader` should come before `model`?
In reply to: 267790407 [](ancestors = 267790407)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I considered it. Actually intuitive cuts two ways here, since this `Save` method interacts with intellisense in a way that I feel makes this arrangement most natural. But I don't insist on it.