Skip to content

Added an extension method for saving statically typed model (#1286) #2924

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/code/experimental/MlNetCookBookStaticApi.md
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ Here's what you do to save the model to a file, and reload it (potentially in a

```csharp
// Saving and loading happens to 'dynamic' models, so the static typing is lost in the process.
mlContext.Model.Save(model.AsDynamic, trainData.AsDynamic.Schema, modelPath);
mlContext.Model.Save(model, trainData, modelPath);

// Potentially, the lines below can be in a different process altogether.

Expand Down
42 changes: 42 additions & 0 deletions src/Microsoft.ML.StaticPipe/ModelOperationsCatalogExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.IO;
using Microsoft.ML.Data;

namespace Microsoft.ML.StaticPipe
{
public static class ModelOperationsCatalogExtensions
{
/// <summary>
/// Save statically typed model to the stream.
/// </summary>
/// <param name="catalog">The model explainability operations catalog.</param>
/// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand
/// for an empty transformer chain. Upon loading with <see cref="ML.ModelOperationsCatalog.Load(Stream, out DataViewSchema)"/> the returned value will
/// be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param>
/// <param name="dataView">The data view with the schema of the input to the transformer. This can be <see langword="null"/>.</param>
/// <param name="stream">A writeable, seekable stream to save to.</param>
public static void Save<TInShape, TOutShape, TTransformer>(this ML.ModelOperationsCatalog catalog, Transformer<TInShape, TOutShape, TTransformer> model, DataView<TInShape> dataView, Stream stream)
where TTransformer : class, ITransformer
{
catalog.Save(model?.AsDynamic, dataView?.AsDynamic.Schema, stream);
}

/// <summary>
/// Save statically typed model to the stream.
/// </summary>
/// <param name="catalog">The model explainability operations catalog.</param>
/// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand
/// for an empty transformer chain. Upon loading with <see cref="ML.ModelOperationsCatalog.Load(Stream, out DataViewSchema)"/> the returned value will
/// be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param>
/// <param name="dataView">The data view with the schema of the input to the transformer. This can be <see langword="null"/>.</param>
/// <param name="filePath">Path where model should be saved.</param>
public static void Save<TInShape, TOutShape, TTransformer>(this ML.ModelOperationsCatalog catalog, Transformer<TInShape, TOutShape, TTransformer> model, DataView<TInShape> dataView, string filePath)
where TTransformer : class, ITransformer
{
catalog.Save(model?.AsDynamic, dataView?.AsDynamic.Schema, filePath);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ private void TrainRegression(string trainDataPath, string testDataPath, string m
var metrics = mlContext.Regression.Evaluate(model.Transform(testData), label: r => r.Target, score: r => r.Prediction);

// Saving and loading happens to 'dynamic' models, so the static typing is lost in the process.
mlContext.Model.Save(model.AsDynamic, trainData.AsDynamic.Schema, modelPath);
mlContext.Model.Save(model, trainData, modelPath);

// Potentially, the lines below can be in a different process altogether.

Expand Down