Remove model saving/loading inconsistencies #3044

Merged · 5 commits · Mar 22, 2019
19 changes: 8 additions & 11 deletions docs/code/MlNetCookBook.md
@@ -383,19 +383,18 @@ var metrics = mlContext.Regression.Evaluate(model.Transform(testData), labelColu

Assuming that the model metrics look good to you, it's time to 'operationalize' the model. This is where ML.NET really shines: the `model` object you just built is ready for immediate consumption, it will apply all the same steps that it has 'learned' during training, and it can be persisted and reused in different environments.

Here's what you do to save the model to a file, and reload it (potentially in a different context).
Here's what you do to save the model as well as its input schema to a file, and reload it (potentially in a different context).

```csharp
using (var stream = File.Create(modelPath))
{
mlContext.Model.Save(model, stream);
}
// Saving and loading happens to transformers. We save the input schema with this model.
mlContext.Model.Save(model, trainData.Schema, modelPath);

// Potentially, the lines below can be in a different process altogether.
ITransformer loadedModel;
using (var stream = File.OpenRead(modelPath))
loadedModel = mlContext.Model.Load(stream);
// When you load the model, it's a non-specific ITransformer. We also recover
// the original schema.
ITransformer loadedModel = mlContext.Model.Load(modelPath, out var schema);
```

## How do I use the model to make one prediction?

Since any ML.NET model is a transformer, you can of course use `model.Transform` to apply the model to the 'data view' and obtain predictions this way.
@@ -1018,7 +1017,5 @@ using (var fs = File.Create(modelPath))
newContext.ComponentCatalog.RegisterAssembly(typeof(CustomMappings).Assembly);

// Now we can load the model.
ITransformer loadedModel;
using (var fs = File.OpenRead(modelPath))
loadedModel = newContext.Model.Load(fs);
ITransformer loadedModel = newContext.Model.Load(modelPath, out var schema);
```
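Taken together, the cookbook edits above boil down to a two-line round trip. The sketch below restates it under the same assumptions as the cookbook: a trained `model`, a `trainData` data view, and a `modelPath` string.

```csharp
// Save the transformer together with the schema of its training data.
mlContext.Model.Save(model, trainData.Schema, modelPath);

// Later, possibly in a different process: load the model back.
// The result is a non-specific ITransformer, and the original input
// schema is recovered through the out parameter.
ITransformer loadedModel = mlContext.Model.Load(modelPath, out DataViewSchema inputSchema);
```

The recovered `inputSchema` can then be used, for example, to validate incoming data before calling `loadedModel.Transform`.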
11 changes: 3 additions & 8 deletions docs/code/experimental/MlNetCookBookStaticApi.md
@@ -396,18 +396,13 @@ This is where ML.NET really shines: the `model` object you just built is ready f
Here's what you do to save the model to a file, and reload it (potentially in a different context).

```csharp
using (var stream = File.Create(modelPath))
{
// Saving and loading happens to 'dynamic' models, so the static typing is lost in the process.
mlContext.Model.Save(model.AsDynamic, stream);
}
// Saving and loading happens to 'dynamic' models, so the static typing is lost in the process.
mlContext.Model.Save(model.AsDynamic, trainData.AsDynamic.Schema, modelPath);

// Potentially, the lines below can be in a different process altogether.

// When you load the model, it's a 'dynamic' transformer.
ITransformer loadedModel;
using (var stream = File.OpenRead(modelPath))
loadedModel = mlContext.Model.Load(stream);
ITransformer loadedModel = mlContext.Model.Load(modelPath, out var schema);
```

## How do I use the model to make one prediction?
172 changes: 109 additions & 63 deletions src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs
@@ -29,63 +29,60 @@ internal ModelOperationsCatalog(IHostEnvironment env)
}

/// <summary>
/// Save the model to the stream.
/// Save a transformer model and the loader used to create its input data to the stream.
/// </summary>
/// <param name="model">The trained model to be saved.</param>
/// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand
/// for an empty transformer chain. Upon loading with <see cref="LoadWithDataLoader(Stream, out IDataLoader{IMultiStreamSource})"/>
/// the returned value will be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param>
/// <param name="loader">The loader that was used to create data to train the model.</param>
/// <param name="stream">A writeable, seekable stream to save to.</param>
public void Save<TSource>(IDataLoader<TSource> model, Stream stream)
public void Save<TSource>(ITransformer model, IDataLoader<TSource> loader, Stream stream)
Comment from the contributor (PR author):
Major thing on the top of my mind: the save/load methods are still inconsistent w.r.t. where they put the file or stream argument... here we see it is the last argument, whereas during loading it is the first argument.

One possibility is we leave it as is, since in this case it is where the model is going (so maybe last is most clear), whereas during loading it is where the model is coming from.

Reply from a reviewer:
I kind of like this possibility.
If we're going with the "intuitive" ordering, then maybe loader should come before model?


(In reply to: 267790407)

Reply from the contributor (PR author):
Yes I considered it. Actually intuitive cuts two ways here, since this Save method interacts with intellisense in a way that I feel makes this arrangement most natural. But I don't insist on it.

{
_env.CheckValue(model, nameof(model));
_env.CheckValue(loader, nameof(loader));
_env.CheckValueOrNull(model);
_env.CheckValue(stream, nameof(stream));

// For the sake of consistency of this API specifically, when called upon we save any transformer
// in a single element transformer chain.
var chainedModel = model == null ? null : new TransformerChain<ITransformer>(model);
var compositeLoader = new CompositeDataLoader<TSource, ITransformer>(loader, chainedModel);

using (var rep = RepositoryWriter.CreateNew(stream))
{
ModelSaveContext.SaveModel(rep, model, null);
ModelSaveContext.SaveModel(rep, compositeLoader, null);
rep.Commit();
}
}
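As the new doc comment notes, the `model` argument may be null as shorthand for an empty transformer chain. A hypothetical call site (the `textLoader` and `fs` names are assumed for illustration) might look like:

```csharp
// Save only the loader; the null model stands in for an empty
// TransformerChain, which is what a subsequent LoadWithDataLoader
// call will return.
mlContext.Model.Save(model: null, loader: textLoader, stream: fs);
```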

/// <summary>
/// Save the model to the file.
/// </summary>
/// <param name="model">The trained model to be saved.</param>
/// <param name="filePath">Path where model should be saved.</param>
public void Save<TSource>(IDataLoader<TSource> model, string filePath)
{
using (var stream = File.Create(filePath))
Save(model, stream);
}

/// <summary>
/// Save a transformer model and the loader used to create its input data to the stream.
/// </summary>
/// <param name="loader">The loader that was used to create data to train the model</param>
/// <param name="model">The trained model to be saved</param>
/// <param name="stream">A writeable, seekable stream to save to.</param>
public void Save<TSource>(IDataLoader<TSource> loader, ITransformer model, Stream stream) =>
Save(new CompositeDataLoader<TSource, ITransformer>(loader, new TransformerChain<ITransformer>(model)), stream);

/// <summary>
/// Save a transformer model and the loader used to create its input data to the file.
/// </summary>
/// <param name="loader">The loader that was used to create data to train the model</param>
/// <param name="model">The trained model to be saved</param>
/// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand
/// for an empty transformer chain. Upon loading with <see cref="LoadWithDataLoader(Stream, out IDataLoader{IMultiStreamSource})"/>
/// the returned value will be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param>
/// <param name="loader">The loader that was used to create data to train the model.</param>
/// <param name="filePath">Path where model should be saved.</param>
public void Save<TSource>(IDataLoader<TSource> loader, ITransformer model, string filePath)
public void Save<TSource>(ITransformer model, IDataLoader<TSource> loader, string filePath)
Comment from the contributor (PR author):
I'm also curious about whether we agree that what is effectively the descriptor of input to the model should come after the model itself. The API had been inconsistent, in that in one save method we had loader/model (so input to model came first), whereas in another we had model/schema (so input to model came second).

In all cases I feel like the load methods should always just return ITransformer, so maybe just always putting ITransformer front-and-center for saving was probably the best approach, so I resolved the inconsistency that way, but I don't insist on it.

I kept it as a separate commit in case we wanted to go the other way, maybe... so in that case, the schema.

{
_env.CheckValueOrNull(model);
_env.CheckValue(loader, nameof(loader));
_env.CheckNonEmpty(filePath, nameof(filePath));

using (var stream = File.Create(filePath))
Save(loader, model, stream);
Save(model, loader, stream);
}

/// <summary>
/// Save a transformer model and the schema of the data that was used to train it to the stream.
/// </summary>
/// <param name="model">The trained model to be saved.</param>
/// <param name="inputSchema">The schema of the input to the transformer. This can be null.</param>
/// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand
/// for an empty transformer chain. Upon loading with <see cref="Load(Stream, out DataViewSchema)"/> the returned value will
/// be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param>
/// <param name="inputSchema">The schema of the input to the transformer. This can be <see langword="null"/>.</param>
/// <param name="stream">A writeable, seekable stream to save to.</param>
public void Save(ITransformer model, DataViewSchema inputSchema, Stream stream)
{
_env.CheckValue(model, nameof(model));
_env.CheckValueOrNull(model);
_env.CheckValueOrNull(inputSchema);
_env.CheckValue(stream, nameof(stream));

@@ -100,11 +97,17 @@ public void Save(ITransformer model, DataViewSchema inputSchema, Stream stream)
/// <summary>
/// Save a transformer model and the schema of the data that was used to train it to the file.
/// </summary>
/// <param name="model">The trained model to be saved.</param>
/// <param name="inputSchema">The schema of the input to the transformer. This can be null.</param>
/// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand
/// for an empty transformer chain. Upon loading with <see cref="Load(Stream, out DataViewSchema)"/> the returned value will
/// be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param>
/// <param name="inputSchema">The schema of the input to the transformer. This can be <see langword="null"/>.</param>
/// <param name="filePath">Path where model should be saved.</param>
public void Save(ITransformer model, DataViewSchema inputSchema, string filePath)
{
_env.CheckValueOrNull(model);
_env.CheckValueOrNull(inputSchema);
_env.CheckNonEmpty(filePath, nameof(filePath));

using (var stream = File.Create(filePath))
Save(model, inputSchema, stream);
}
@@ -126,11 +129,11 @@ private void SaveInputSchema(DataViewSchema inputSchema, RepositoryWriter rep)
}

/// <summary>
/// Load the model and its input schema from the stream.
/// Load the model and its input schema from a stream.
/// </summary>
/// <param name="stream">A readable, seekable stream to load from.</param>
/// <param name="inputSchema">Will contain the input schema for the model. If the model was saved using older APIs
/// it may not contain an input schema, in this case <paramref name="inputSchema"/> will be null.</param>
/// <param name="inputSchema">Will contain the input schema for the model. If the model was saved without
/// any description of the input, there will be no input schema. In this case this can be <see langword="null"/>.</param>
/// <returns>The loaded model.</returns>
public ITransformer Load(Stream stream, out DataViewSchema inputSchema)
{
@@ -171,57 +174,100 @@ public ITransformer Load(Stream stream, out DataViewSchema inputSchema)
throw _env.Except(ex, "Could not load legacy format model");
}
}
if (dataLoader is CompositeDataLoader<IMultiStreamSource, ITransformer> composite)
{
inputSchema = composite.Loader.GetOutputSchema();
return composite.Transformer;
}
var transformer = DecomposeLoader(ref dataLoader);
inputSchema = dataLoader.GetOutputSchema();
return new TransformerChain<ITransformer>();
return transformer;
}
}

/// <summary>
/// Load the model and its input schema from a file.
/// </summary>
/// <param name="filePath">Path to a file where the model should be read from.</param>
/// <param name="inputSchema">Will contain the input schema for the model. If the model was saved without
/// any description of the input, there will be no input schema. In this case this can be <see langword="null"/>.</param>
/// <returns>The loaded model.</returns>
public ITransformer Load(string filePath, out DataViewSchema inputSchema)
{
_env.CheckNonEmpty(filePath, nameof(filePath));

using (var stream = File.OpenRead(filePath))
return Load(stream, out inputSchema);
}

/// <summary>
/// Given a loader, try to "decompose" it into a source loader and its transform, if any.
/// If necessary an empty chain will be created to stand in for the trivial transformation; it
/// should never return <see langword="null"/>.
/// </summary>
private ITransformer DecomposeLoader(ref IDataLoader<IMultiStreamSource> loader)
{
_env.AssertValue(loader);

if (loader is CompositeDataLoader<IMultiStreamSource, ITransformer> composite)
{
loader = composite.Loader;
var chain = composite.Transformer;
// The save method corresponding to this load method encapsulates the input ITransformer
// in a single-element transformer chain. If the stored chain has exactly one element,
// we assume it was saved that way and return that single transformer.
var accessor = (ITransformerChainAccessor)chain;
if (accessor.Transformers.Length == 1)
return accessor.Transformers[0];
// If it is some other length than 1 due to, say, some legacy model saving, just return that
// chain. Using the above API this is not possible, since the chain saved will always be of length
// one, but older APIs behaved differently so we should retain flexibility with those schemes.
// (Those schemes are BTW by no means incorrect, they just aren't what the API in this particular
// class will specifically do.)
return chain;
}
// Maybe we have no transformer stored. Rather than return null, we prefer to return the
// empty "trivial" transformer chain.
return new TransformerChain<ITransformer>();
}

/// <summary>
/// Load the model and its input schema from the stream.
/// Load a transformer model and a data loader model from a stream.
/// </summary>
/// <param name="stream">A readable, seekable stream to load from.</param>
/// <returns>A model of type <see cref="CompositeDataLoader{IMultiStreamSource, ITransformer}"/> containing the loader
/// and the transformer chain.</returns>
public IDataLoader<IMultiStreamSource> Load(Stream stream)
/// <param name="loader">The data loader from the model stream. Note that if there is no data loader,
/// this method will throw an exception. The scenario where no loader is stored in the stream should
/// be handled instead using the <see cref="Load(Stream, out DataViewSchema)"/> method.</param>
/// <returns>The transformer model from the model stream.</returns>
public ITransformer LoadWithDataLoader(Stream stream, out IDataLoader<IMultiStreamSource> loader)
{
_env.CheckValue(stream, nameof(stream));

using (var rep = RepositoryReader.Open(stream))
{
try
{
ModelLoadContext.LoadModel<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(_env, out var model, rep, null);
return model;
ModelLoadContext.LoadModel<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(_env, out loader, rep, null);
return DecomposeLoader(ref loader);
}
catch (Exception ex)
{
throw _env.Except(ex, "Model does not contain an IDataLoader");
throw _env.Except(ex, "Model does not contain an " + nameof(IDataLoader<IMultiStreamSource>) +
". Perhaps this was saved with an " + nameof(DataViewSchema) + ", or even no information on its input at all. " +
"Consider using the " + nameof(Load) + " method instead.");
}
}
}

/// <summary>
/// Load a transformer model and a data loader model from the stream.
/// Load a transformer model and a data loader model from a file.
/// </summary>
/// <param name="stream">A readable, seekable stream to load from.</param>
/// <param name="loader">The data loader from the model stream.</param>
/// <returns>The transformer model from the model stream.</returns>
public ITransformer LoadWithDataLoader(Stream stream, out IDataLoader<IMultiStreamSource> loader)
/// <param name="filePath">Path to a file where the model should be read from.</param>
/// <param name="loader">The data loader from the model stream. Note that if there is no data loader,
/// this method will throw an exception. The scenario where no loader is stored in the stream should
/// be handled instead using the <see cref="Load(Stream, out DataViewSchema)"/> method.</param>
/// <returns>The transformer model from the model file.</returns>
public ITransformer LoadWithDataLoader(string filePath, out IDataLoader<IMultiStreamSource> loader)
{
_env.CheckValue(stream, nameof(stream));
_env.CheckNonEmpty(filePath, nameof(filePath));

loader = Load(stream);
if (loader is CompositeDataLoader<IMultiStreamSource, ITransformer> composite)
{
loader = composite.Loader;
return composite.Transformer;
}
return new TransformerChain<ITransformer>();
using (var stream = File.OpenRead(filePath))
return LoadWithDataLoader(stream, out loader);
}
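The two load families defined above differ in what they recover alongside the transformer, and the exception message points from one to the other. A brief usage sketch (the `modelPath` name is assumed):

```csharp
// Saved with a DataViewSchema, or with no input description at all:
// Load returns the transformer and the schema (possibly null).
ITransformer modelA = mlContext.Model.Load(modelPath, out DataViewSchema schema);

// Saved with an IDataLoader: LoadWithDataLoader also recovers the loader,
// and throws if no loader was stored in the file.
ITransformer modelB = mlContext.Model.LoadWithDataLoader(modelPath,
    out IDataLoader<IMultiStreamSource> loader);
```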

/// <summary>