Skip to content

Commit 2cf7ee4

Browse files
committed
Make IDataLoader<TSource> inherit from ICanSaveModel, and add methods to save/load it in ModelOperationsCatalog.
1 parent 7b08c00 commit 2cf7ee4

File tree

6 files changed

+135
-37
lines changed

6 files changed

+135
-37
lines changed

src/Microsoft.ML.Core/Data/IEstimator.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ internal bool TryFindColumn(string name, out Column column)
224224
/// The 'data loader' takes a certain kind of input and turns it into an <see cref="IDataView"/>.
225225
/// </summary>
226226
/// <typeparam name="TSource">The type of input the loader takes.</typeparam>
227-
public interface IDataLoader<in TSource>
227+
public interface IDataLoader<in TSource> : ICanSaveModel
228228
{
229229
/// <summary>
230230
/// Produce the data view from the specified input.

src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs

+49-33
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44

55
using System.IO;
66
using Microsoft.Data.DataView;
7-
using Microsoft.ML.Model;
7+
using Microsoft.ML;
8+
using Microsoft.ML.Data;
9+
10+
[assembly: LoadableClass(CompositeDataLoader<IMultiStreamSource, ITransformer>.Summary, typeof(CompositeDataLoader<IMultiStreamSource, ITransformer>), null, typeof(SignatureLoadModel),
11+
"Composite Loader", CompositeDataLoader<IMultiStreamSource, ITransformer>.LoaderSignature)]
812

913
namespace Microsoft.ML.Data
1014
{
@@ -15,6 +19,10 @@ namespace Microsoft.ML.Data
1519
public sealed class CompositeDataLoader<TSource, TLastTransformer> : IDataLoader<TSource>
1620
where TLastTransformer : class, ITransformer
1721
{
22+
private const string LoaderDirectory = "Loader";
23+
private const string LegacyLoaderDirectory = "Reader";
24+
private const string TransformerDirectory = TransformerChain.LoaderSignature;
25+
1826
/// <summary>
1927
/// The underlying data loader.
2028
/// </summary>
@@ -33,6 +41,24 @@ public CompositeDataLoader(IDataLoader<TSource> loader, TransformerChain<TLastTr
3341
Transformer = transformerChain ?? new TransformerChain<TLastTransformer>();
3442
}
3543

44+
/// <summary>
/// Deserialization constructor: populates <c>Loader</c> and <c>Transformer</c> from the
/// sub-models stored under <paramref name="ctx"/>.
/// </summary>
/// <param name="host">Host handed to each sub-model load call.</param>
/// <param name="ctx">Load context the loader and the transformer chain are read from.</param>
private CompositeDataLoader(IHost host, ModelLoadContext ctx)
{
    // The legacy directory name is probed first; the current directory name is only
    // consulted when the legacy probe reports that nothing was loaded.
    bool loadedFromLegacyDirectory =
        ctx.LoadModelOrNull<IDataLoader<TSource>, SignatureLoadModel>(host, out Loader, LegacyLoaderDirectory);
    if (!loadedFromLegacyDirectory)
    {
        ctx.LoadModel<IDataLoader<TSource>, SignatureLoadModel>(host, out Loader, LoaderDirectory);
    }

    ctx.LoadModel<TransformerChain<TLastTransformer>, SignatureLoadModel>(host, out Transformer, TransformerDirectory);
}
50+
51+
/// <summary>
/// Builds a <see cref="CompositeDataLoader{TSource, TLastTransformer}"/> from a saved model.
/// </summary>
/// <remarks>
/// NOTE(review): presumably discovered through the assembly-level LoadableClass registration
/// for SignatureLoadModel - confirm against the component catalog.
/// </remarks>
/// <param name="env">Environment used to register the host for this load.</param>
/// <param name="ctx">Load context; it is checked against this class's version info.</param>
/// <returns>The loader reconstructed from <paramref name="ctx"/>.</returns>
private static CompositeDataLoader<TSource, TLastTransformer> Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));

    IHost host = env.Register(LoaderSignature);
    host.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    return host.Apply(
        "Loading Model",
        channel => new CompositeDataLoader<TSource, TLastTransformer>(host, ctx));
}
61+
3662
/// <summary>
3763
/// Produce the data view from the specified input.
3864
/// Note that <see cref="IDataView"/>'s are lazy, so no actual loading happens here, just schema validation.
@@ -62,6 +88,16 @@ public CompositeDataLoader<TSource, TNewLast> AppendTransformer<TNewLast>(TNewLa
6288
return new CompositeDataLoader<TSource, TNewLast>(Loader, Transformer.Append(transformer));
6389
}
6490

91+
/// <summary>
/// Writes this composite loader into <paramref name="ctx"/>: the version info is set first,
/// then the wrapped loader and the transformer chain are saved as sub-models.
/// </summary>
/// <param name="ctx">The save context to write into.</param>
void ICanSaveModel.Save(ModelSaveContext ctx)
{
    Contracts.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel();
    ctx.SetVersionInfo(GetVersionInfo());

    // Each part goes into the directory named by the matching class constant; the
    // deserialization constructor reads the same constants back.
    ctx.SaveModel(Loader, LoaderDirectory);
    ctx.SaveModel(Transformer, TransformerDirectory);
}
100+
65101
/// <summary>
66102
/// Save the contents to a stream, as a "model file".
67103
/// </summary>
@@ -76,47 +112,27 @@ public void SaveTo(IHostEnvironment env, Stream outputStream)
76112
using (var rep = RepositoryWriter.CreateNew(outputStream, ch))
77113
{
78114
ch.Trace("Saving data loader");
79-
ModelSaveContext.SaveModel(rep, Loader, "Reader");
115+
ModelSaveContext.SaveModel(rep, Loader, LoaderDirectory);
80116

81117
ch.Trace("Saving transformer chain");
82-
ModelSaveContext.SaveModel(rep, Transformer, TransformerChain.LoaderSignature);
118+
ModelSaveContext.SaveModel(rep, Transformer, TransformerDirectory);
83119
rep.Commit();
84120
}
85121
}
86122
}
87-
}
88123

89-
/// <summary>
90-
/// Utility class to facilitate loading from a stream.
91-
/// </summary>
92-
[BestFriend]
93-
internal static class CompositeDataLoader
94-
{
95-
/// <summary>
96-
/// Save the contents to a stream, as a "model file".
97-
/// </summary>
98-
public static void SaveTo<TSource>(this IDataLoader<TSource> loader, IHostEnvironment env, Stream outputStream)
99-
=> new CompositeDataLoader<TSource, ITransformer>(loader).SaveTo(env, outputStream);
124+
internal const string Summary = "A loader that encapsulates a loader and a transformer chain.";
100125

101-
/// <summary>
102-
/// Load the pipeline from stream.
103-
/// </summary>
104-
public static CompositeDataLoader<IMultiStreamSource, ITransformer> LoadFrom(IHostEnvironment env, Stream stream)
126+
internal const string LoaderSignature = "CompositeLoader";
127+
private static VersionInfo GetVersionInfo()
105128
{
106-
Contracts.CheckValue(env, nameof(env));
107-
env.CheckValue(stream, nameof(stream));
108-
109-
env.Check(stream.CanRead && stream.CanSeek, "Need a readable and seekable stream to load");
110-
using (var rep = RepositoryReader.Open(stream, env))
111-
using (var ch = env.Start("Loading pipeline"))
112-
{
113-
ch.Trace("Loading data loader");
114-
ModelLoadContext.LoadModel<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(env, out var loader, rep, "Reader");
115-
116-
ch.Trace("Loader transformer chain");
117-
ModelLoadContext.LoadModel<TransformerChain<ITransformer>, SignatureLoadModel>(env, out var transformerChain, rep, TransformerChain.LoaderSignature);
118-
return new CompositeDataLoader<IMultiStreamSource, ITransformer>(loader, transformerChain);
119-
}
129+
return new VersionInfo(
130+
modelSignature: "CMPSTLDR",
131+
verWrittenCur: 0x00010001, // Initial
132+
verReadableCur: 0x00010001,
133+
verWeCanReadBack: 0x00010001,
134+
loaderSignature: LoaderSignature,
135+
loaderAssemblyName: typeof(CompositeDataLoader<,>).Assembly.FullName);
120136
}
121137
}
122138
}

src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs

+4-2
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,22 @@
1212
using Microsoft.ML.CommandLine;
1313
using Microsoft.ML.Data;
1414
using Microsoft.ML.Internal.Utilities;
15-
using Microsoft.ML.Model;
1615

1716
[assembly: LoadableClass(TextLoader.Summary, typeof(ILegacyDataLoader), typeof(TextLoader), typeof(TextLoader.Options), typeof(SignatureDataLoader),
1817
"Text Loader", "TextLoader", "Text", DocName = "loader/TextLoader.md")]
1918

2019
[assembly: LoadableClass(TextLoader.Summary, typeof(ILegacyDataLoader), typeof(TextLoader), null, typeof(SignatureLoadDataLoader),
2120
"Text Loader", TextLoader.LoaderSignature)]
2221

22+
[assembly: LoadableClass(TextLoader.Summary, typeof(TextLoader), null, typeof(SignatureLoadModel),
23+
"Text Loader", TextLoader.LoaderSignature)]
24+
2325
namespace Microsoft.ML.Data
2426
{
2527
/// <summary>
2628
/// Loads a text file into an IDataView. Supports basic mapping from input columns to <see cref="IDataView"/> columns.
2729
/// </summary>
28-
public sealed partial class TextLoader : IDataLoader<IMultiStreamSource>, ICanSaveModel
30+
public sealed partial class TextLoader : IDataLoader<IMultiStreamSource>
2931
{
3032
/// <summary>
3133
/// Describes how an input column should be mapped to an <see cref="IDataView"/> column.

src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,9 @@ public static TransformerChain<ITransformer> LoadFrom(IHostEnvironment env, Stre
255255
{
256256
try
257257
{
258-
ModelLoadContext.LoadModel<TransformerChain<ITransformer>, SignatureLoadModel>(env, out var transformerChain, rep, LoaderSignature);
258+
ModelLoadContext.LoadModelOrNull<TransformerChain<ITransformer>, SignatureLoadModel>(env, out var transformerChain, rep, LoaderSignature);
259+
if (transformerChain == null)
260+
ModelLoadContext.LoadModel<TransformerChain<ITransformer>, SignatureLoadModel>(env, out transformerChain, rep, $@"Model\{LoaderSignature}");
259261
return transformerChain;
260262
}
261263
catch (FormatException ex)

src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs

+21
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,34 @@ protected SubCatalogBase(ModelOperationsCatalog owner)
4545
/// <param name="stream">A writeable, seekable stream to save to.</param>
4646
public void Save(ITransformer model, Stream stream) => model.SaveTo(Environment, stream);
4747

48+
/// <summary>
/// Save the data loader to the stream.
/// </summary>
/// <typeparam name="TSource">The type of input the loader takes.</typeparam>
/// <param name="model">The data loader to save.</param>
/// <param name="stream">A writeable, seekable stream to save to.</param>
public void Save<TSource>(IDataLoader<TSource> model, Stream stream)
{
    // Validate up front so a null argument is reported by name instead of failing
    // somewhere inside the repository writer.
    Environment.CheckValue(model, nameof(model));
    Environment.CheckValue(stream, nameof(stream));

    using (var rep = RepositoryWriter.CreateNew(stream))
    {
        // The loader is stored under the "Model" directory; LoadAsCompositeDataLoader reads it from there.
        ModelSaveContext.SaveModel(rep, model, "Model");
        rep.Commit();
    }
}
56+
57+
/// <summary>
/// Save a data loader together with a transformer to the stream, by wrapping both in a
/// <see cref="CompositeDataLoader{TSource, TLastTransformer}"/> and saving that.
/// </summary>
/// <typeparam name="TSource">The type of input the loader takes.</typeparam>
/// <param name="loader">The data loader to save.</param>
/// <param name="model">The transformer placed in the chain that follows the loader.</param>
/// <param name="stream">The stream to save to.</param>
public void Save<TSource>(IDataLoader<TSource> loader, ITransformer model, Stream stream)
{
    var chain = new TransformerChain<ITransformer>(model);
    var composite = new CompositeDataLoader<TSource, ITransformer>(loader, chain);
    Save(composite, stream);
}
59+
4860
/// <summary>
4961
/// Load the model from the stream.
5062
/// </summary>
5163
/// <param name="stream">A readable, seekable stream to load from.</param>
5264
/// <returns>The loaded model.</returns>
5365
public ITransformer Load(Stream stream) => TransformerChain.LoadFrom(Environment, stream);
5466

67+
/// <summary>
/// Load a data loader from the stream. The loader is read from the "Model" directory of the repository.
/// </summary>
/// <param name="stream">A readable, seekable stream to load from.</param>
/// <returns>The loaded data loader.</returns>
public IDataLoader<IMultiStreamSource> LoadAsCompositeDataLoader(Stream stream)
{
    // Validate up front so a null stream is reported by name instead of failing
    // somewhere inside the repository reader.
    Environment.CheckValue(stream, nameof(stream));

    using (var rep = RepositoryReader.Open(stream))
    {
        ModelLoadContext.LoadModel<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(Environment, out var model, rep, "Model");
        return model;
    }
}
75+
5576
/// <summary>
5677
/// The catalog of model explainability operations.
5778
/// </summary>

test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs

+57
Original file line numberDiff line numberDiff line change
@@ -53,5 +53,62 @@ public void LoadModelAndExtractPredictor()
5353
as BinaryClassificationGamModelParameters;
5454
Assert.NotNull(gam);
5555
}
56+
57+
/// <summary>
/// Trains a GAM on the adult dataset, saves it together with its text loader, and checks that
/// reloading through the loader keeps the slot names that are lost when only the transformer is loaded.
/// </summary>
[Fact]
public void SaveAndLoadModelWithLoader()
{
    var ml = new MLContext(seed: 1, conc: 1);
    var file = new MultiFileSource(GetDataPath(TestDatasets.adult.trainFilename));
    var loader = ml.Data.CreateTextLoader<InputData>(hasHeader: true, dataSample: file);
    var data = loader.Load(file);

    // Pipeline.
    var pipeline = ml.BinaryClassification.Trainers.GeneralizedAdditiveModels();

    // Train.
    var model = pipeline.Fit(data);

    // Save and reload.
    string modelPath = GetOutputPath(FullTestName + "-model.zip");
    using (var fs = File.Create(modelPath))
    {
        ml.Model.Save(loader, model, fs);
    }

    IDataLoader<IMultiStreamSource> loadedModel;
    ITransformer loadedModelWithoutLoader;
    using (var fs = File.OpenRead(modelPath))
    {
        loadedModel = ml.Model.LoadAsCompositeDataLoader(fs);
        loadedModelWithoutLoader = ml.Model.Load(fs);
    }

    // Without deserializing the loader from the model we lose the slot names.
    data = ml.Data.LoadFromEnumerable(new[] { new InputData() });
    data = loadedModelWithoutLoader.Transform(data);
    Assert.Null(data.Schema["Features"].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames));

    data = loadedModel.Load(file);
    Assert.True(data.Schema["Features"].HasSlotNames(data.Schema["Features"].Type.GetValueCount()));
    VBuffer<ReadOnlyMemory<char>> slotNames = default;
    data.Schema["Features"].GetSlotNames(ref slotNames);

    // FindIndex reports a missing slot as -1; fail here rather than pass a bad index on.
    var ageIndex = FindIndex(slotNames.GetValues(), "age");
    Assert.True(ageIndex >= 0, "Slot name 'age' was not found in the Features column.");

    // Unwrap step by step so a wrong runtime type fails as an assertion, not as a NullReferenceException.
    var compositeLoader = loadedModel as CompositeDataLoader<IMultiStreamSource, ITransformer>;
    Assert.NotNull(compositeLoader);
    var predictionTransformer = compositeLoader.Transformer.LastTransformer as BinaryPredictionTransformer<object>;
    Assert.NotNull(predictionTransformer);
    var calibratedModel = predictionTransformer.Model as CalibratedModelParametersBase<object, ICalibrator>;
    Assert.NotNull(calibratedModel);
    var gamModel = calibratedModel.SubModel as BinaryClassificationGamModelParameters;
    Assert.NotNull(gamModel);

    // The lookups themselves are the check: they must succeed on the reloaded predictor.
    var ageBinUpperBounds = gamModel.GetBinUpperBounds(ageIndex);
    var ageBinEffects = gamModel.GetBinEffects(ageIndex);
}
101+
102+
/// <summary>
/// Finds the first entry of <paramref name="values"/> whose characters equal
/// <paramref name="slotName"/> exactly.
/// </summary>
/// <param name="values">The candidate names to scan, in order.</param>
/// <param name="slotName">The name to look for.</param>
/// <returns>The zero-based position of the first match, or -1 when nothing matches.</returns>
private int FindIndex(ReadOnlySpan<ReadOnlyMemory<char>> values, string slotName)
{
    ReadOnlySpan<char> target = slotName.AsSpan();
    for (int position = 0; position < values.Length; position++)
    {
        if (values[position].Span.SequenceEqual(target))
        {
            return position;
        }
    }

    return -1;
}
56113
}
57114
}

0 commit comments

Comments
 (0)