Skip to content

Commit 665a366

Browse files
authored
Add save/load APIs for IDataLoader (#2858)
* Add save/load APIs for IDataLoader * Address some code review comments, add a non-generic base class for calibrated predictor * use the contravariance of ISingleFeaturePredictionTransformer instead of loading PredictionTransformer<object> from file * Add API for saving/loading input schema * Fix build after rebase * Add API to create PredictionEngine with input schema * Address code review comments * Unfriend Functional.Tests * Add CreatePredictionEngine API back to ModelOperationsCatalog * Address code review comments * Fix build * Fix F# tests * Remove duplicate CreatePredictionEngine API * Add test for creating an IDataView from a loaded schema * Fix build error after rebase * Add unit tests, and address some code review comments * Fix build after rebase * Code review comments
1 parent 3ad489e commit 665a366

File tree

52 files changed

+872
-301
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+872
-301
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/TextClassification.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,13 @@ public static void Example()
6868
j.Features = features;
6969
};
7070

71-
var engine = mlContext.Transforms.Text.TokenizeIntoWords("TokenizedWords", "Sentiment_Text")
71+
var model = mlContext.Transforms.Text.TokenizeIntoWords("TokenizedWords", "Sentiment_Text")
7272
.Append(mlContext.Transforms.Conversion.MapValue(lookupMap, "Words", "Ids", new ColumnOptions[] { ("VariableLenghtFeatures", "TokenizedWords") }))
7373
.Append(mlContext.Transforms.CustomMapping(ResizeFeaturesAction, "Resize"))
7474
.Append(tensorFlowModel.ScoreTensorFlowModel(new[] { "Prediction/Softmax" }, new[] { "Features" }))
7575
.Append(mlContext.Transforms.CopyColumns(("Prediction", "Prediction/Softmax")))
76-
.Fit(dataView)
77-
.CreatePredictionEngine<IMDBSentiment, OutputScores>(mlContext);
76+
.Fit(dataView);
77+
var engine = mlContext.Model.CreatePredictionEngine<IMDBSentiment, OutputScores>(model);
7878

7979
// Predict with TensorFlow pipeline.
8080
var prediction = engine.Predict(data[0]);

src/Microsoft.ML.Core/Data/IEstimator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ internal bool TryFindColumn(string name, out Column column)
224224
/// The 'data loader' takes a certain kind of input and turns it into an <see cref="IDataView"/>.
225225
/// </summary>
226226
/// <typeparam name="TSource">The type of input the loader takes.</typeparam>
227-
public interface IDataLoader<in TSource>
227+
public interface IDataLoader<in TSource> : ICanSaveModel
228228
{
229229
/// <summary>
230230
/// Produce the data view from the specified input.

src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs

Lines changed: 44 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System.IO;
5+
using Microsoft.ML;
6+
using Microsoft.ML.Data;
67
using Microsoft.ML.Runtime;
78

9+
[assembly: LoadableClass(CompositeDataLoader<IMultiStreamSource, ITransformer>.Summary, typeof(CompositeDataLoader<IMultiStreamSource, ITransformer>), null, typeof(SignatureLoadModel),
10+
"Composite Loader", CompositeDataLoader<IMultiStreamSource, ITransformer>.LoaderSignature)]
11+
812
namespace Microsoft.ML.Data
913
{
1014
/// <summary>
@@ -14,6 +18,10 @@ namespace Microsoft.ML.Data
1418
public sealed class CompositeDataLoader<TSource, TLastTransformer> : IDataLoader<TSource>
1519
where TLastTransformer : class, ITransformer
1620
{
21+
internal const string TransformerDirectory = TransformerChain.LoaderSignature;
22+
private const string LoaderDirectory = "Loader";
23+
private const string LegacyLoaderDirectory = "Reader";
24+
1725
/// <summary>
1826
/// The underlying data loader.
1927
/// </summary>
@@ -32,6 +40,24 @@ public CompositeDataLoader(IDataLoader<TSource> loader, TransformerChain<TLastTr
3240
Transformer = transformerChain ?? new TransformerChain<TLastTransformer>();
3341
}
3442

43+
private CompositeDataLoader(IHost host, ModelLoadContext ctx)
44+
{
45+
if (!ctx.LoadModelOrNull<IDataLoader<TSource>, SignatureLoadModel>(host, out Loader, LegacyLoaderDirectory))
46+
ctx.LoadModel<IDataLoader<TSource>, SignatureLoadModel>(host, out Loader, LoaderDirectory);
47+
ctx.LoadModel<TransformerChain<TLastTransformer>, SignatureLoadModel>(host, out Transformer, TransformerDirectory);
48+
}
49+
50+
private static CompositeDataLoader<TSource, TLastTransformer> Create(IHostEnvironment env, ModelLoadContext ctx)
51+
{
52+
Contracts.CheckValue(env, nameof(env));
53+
IHost h = env.Register(LoaderSignature);
54+
55+
h.CheckValue(ctx, nameof(ctx));
56+
ctx.CheckAtModel(GetVersionInfo());
57+
58+
return h.Apply("Loading Model", ch => new CompositeDataLoader<TSource, TLastTransformer>(h, ctx));
59+
}
60+
3561
/// <summary>
3662
/// Produce the data view from the specified input.
3763
/// Note that <see cref="IDataView"/>'s are lazy, so no actual loading happens here, just schema validation.
@@ -61,61 +87,28 @@ public CompositeDataLoader<TSource, TNewLast> AppendTransformer<TNewLast>(TNewLa
6187
return new CompositeDataLoader<TSource, TNewLast>(Loader, Transformer.Append(transformer));
6288
}
6389

64-
/// <summary>
65-
/// Save the contents to a stream, as a "model file".
66-
/// </summary>
67-
public void SaveTo(IHostEnvironment env, Stream outputStream)
90+
void ICanSaveModel.Save(ModelSaveContext ctx)
6891
{
69-
Contracts.CheckValue(env, nameof(env));
70-
env.CheckValue(outputStream, nameof(outputStream));
71-
72-
env.Check(outputStream.CanWrite && outputStream.CanSeek, "Need a writable and seekable stream to save");
73-
using (var ch = env.Start("Saving pipeline"))
74-
{
75-
using (var rep = RepositoryWriter.CreateNew(outputStream, ch))
76-
{
77-
ch.Trace("Saving data loader");
78-
ModelSaveContext.SaveModel(rep, Loader, "Reader");
79-
80-
ch.Trace("Saving transformer chain");
81-
ModelSaveContext.SaveModel(rep, Transformer, TransformerChain.LoaderSignature);
82-
rep.Commit();
83-
}
84-
}
92+
Contracts.CheckValue(ctx, nameof(ctx));
93+
ctx.CheckAtModel();
94+
ctx.SetVersionInfo(GetVersionInfo());
95+
96+
ctx.SaveModel(Loader, LoaderDirectory);
97+
ctx.SaveModel(Transformer, TransformerDirectory);
8598
}
86-
}
8799

88-
/// <summary>
89-
/// Utility class to facilitate loading from a stream.
90-
/// </summary>
91-
[BestFriend]
92-
internal static class CompositeDataLoader
93-
{
94-
/// <summary>
95-
/// Save the contents to a stream, as a "model file".
96-
/// </summary>
97-
public static void SaveTo<TSource>(this IDataLoader<TSource> loader, IHostEnvironment env, Stream outputStream)
98-
=> new CompositeDataLoader<TSource, ITransformer>(loader).SaveTo(env, outputStream);
100+
internal const string Summary = "A model loader that encapsulates a data loader and a transformer chain.";
99101

100-
/// <summary>
101-
/// Load the pipeline from stream.
102-
/// </summary>
103-
public static CompositeDataLoader<IMultiStreamSource, ITransformer> LoadFrom(IHostEnvironment env, Stream stream)
102+
internal const string LoaderSignature = "CompositeLoader";
103+
private static VersionInfo GetVersionInfo()
104104
{
105-
Contracts.CheckValue(env, nameof(env));
106-
env.CheckValue(stream, nameof(stream));
107-
108-
env.Check(stream.CanRead && stream.CanSeek, "Need a readable and seekable stream to load");
109-
using (var rep = RepositoryReader.Open(stream, env))
110-
using (var ch = env.Start("Loading pipeline"))
111-
{
112-
ch.Trace("Loading data loader");
113-
ModelLoadContext.LoadModel<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(env, out var loader, rep, "Reader");
114-
115-
ch.Trace("Loader transformer chain");
116-
ModelLoadContext.LoadModel<TransformerChain<ITransformer>, SignatureLoadModel>(env, out var transformerChain, rep, TransformerChain.LoaderSignature);
117-
return new CompositeDataLoader<IMultiStreamSource, ITransformer>(loader, transformerChain);
118-
}
105+
return new VersionInfo(
106+
modelSignature: "CMPSTLDR",
107+
verWrittenCur: 0x00010001, // Initial
108+
verReadableCur: 0x00010001,
109+
verWeCanReadBack: 0x00010001,
110+
loaderSignature: LoaderSignature,
111+
loaderAssemblyName: typeof(CompositeDataLoader<,>).Assembly.FullName);
119112
}
120113
}
121114
}

src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,14 @@ public IDataView LoadFromEnumerable<TRow>(IEnumerable<TRow> data, SchemaDefiniti
8282
return DataViewConstructionUtils.CreateFromEnumerable(_env, data, schemaDefinition);
8383
}
8484

85+
public IDataView LoadFromEnumerable<TRow>(IEnumerable<TRow> data, DataViewSchema schema)
86+
where TRow : class
87+
{
88+
_env.CheckValue(data, nameof(data));
89+
_env.CheckValue(schema, nameof(schema));
90+
return DataViewConstructionUtils.CreateFromEnumerable(_env, data, schema);
91+
}
92+
8593
/// <summary>
8694
/// Convert an <see cref="IDataView"/> into a strongly-typed <see cref="IEnumerable{TRow}"/>.
8795
/// </summary>

src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs

Lines changed: 41 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,15 @@
1919
[assembly: LoadableClass(TextLoader.Summary, typeof(ILegacyDataLoader), typeof(TextLoader), null, typeof(SignatureLoadDataLoader),
2020
"Text Loader", TextLoader.LoaderSignature)]
2121

22+
[assembly: LoadableClass(TextLoader.Summary, typeof(TextLoader), null, typeof(SignatureLoadModel),
23+
"Text Loader", TextLoader.LoaderSignature)]
24+
2225
namespace Microsoft.ML.Data
2326
{
2427
/// <summary>
2528
/// Loads a text file into an IDataView. Supports basic mapping from input columns to <see cref="IDataView"/> columns.
2629
/// </summary>
27-
public sealed partial class TextLoader : IDataLoader<IMultiStreamSource>, ICanSaveModel
30+
public sealed partial class TextLoader : IDataLoader<IMultiStreamSource>
2831
{
2932
/// <summary>
3033
/// Describes how an input column should be mapped to an <see cref="IDataView"/> column.
@@ -1189,31 +1192,31 @@ private char NormalizeSeparator(string sep)
11891192
{
11901193
switch (sep)
11911194
{
1192-
case "space":
1193-
case " ":
1194-
return ' ';
1195-
case "tab":
1196-
case "\t":
1197-
return '\t';
1198-
case "comma":
1199-
case ",":
1200-
return ',';
1201-
case "colon":
1202-
case ":":
1203-
_host.CheckUserArg((_flags & OptionFlags.AllowSparse) == 0, nameof(Options.Separator),
1204-
"When the separator is colon, turn off allowSparse");
1205-
return ':';
1206-
case "semicolon":
1207-
case ";":
1208-
return ';';
1209-
case "bar":
1210-
case "|":
1211-
return '|';
1212-
default:
1213-
char ch = sep[0];
1214-
if (sep.Length != 1 || ch < ' ' || '0' <= ch && ch <= '9' || ch == '"')
1215-
throw _host.ExceptUserArg(nameof(Options.Separator), "Illegal separator: '{0}'", sep);
1216-
return sep[0];
1195+
case "space":
1196+
case " ":
1197+
return ' ';
1198+
case "tab":
1199+
case "\t":
1200+
return '\t';
1201+
case "comma":
1202+
case ",":
1203+
return ',';
1204+
case "colon":
1205+
case ":":
1206+
_host.CheckUserArg((_flags & OptionFlags.AllowSparse) == 0, nameof(Options.Separator),
1207+
"When the separator is colon, turn off allowSparse");
1208+
return ':';
1209+
case "semicolon":
1210+
case ";":
1211+
return ';';
1212+
case "bar":
1213+
case "|":
1214+
return '|';
1215+
default:
1216+
char ch = sep[0];
1217+
if (sep.Length != 1 || ch < ' ' || '0' <= ch && ch <= '9' || ch == '"')
1218+
throw _host.ExceptUserArg(nameof(Options.Separator), "Illegal separator: '{0}'", sep);
1219+
return sep[0];
12171220
}
12181221
}
12191222

@@ -1310,7 +1313,7 @@ private static bool TryParseSchema(IHost host, IMultiStreamSource files,
13101313
error = false;
13111314
options = optionsNew;
13121315

1313-
LDone:
1316+
LDone:
13141317
return !error;
13151318
}
13161319
}
@@ -1470,20 +1473,20 @@ internal static TextLoader CreateTextLoader<TInput>(IHostEnvironment host,
14701473
InternalDataKind dk;
14711474
switch (memberInfo)
14721475
{
1473-
case FieldInfo field:
1474-
if (!InternalDataKindExtensions.TryGetDataKind(field.FieldType.IsArray ? field.FieldType.GetElementType() : field.FieldType, out dk))
1475-
throw Contracts.Except($"Field {memberInfo.Name} is of unsupported type.");
1476+
case FieldInfo field:
1477+
if (!InternalDataKindExtensions.TryGetDataKind(field.FieldType.IsArray ? field.FieldType.GetElementType() : field.FieldType, out dk))
1478+
throw Contracts.Except($"Field {memberInfo.Name} is of unsupported type.");
14761479

1477-
break;
1480+
break;
14781481

1479-
case PropertyInfo property:
1480-
if (!InternalDataKindExtensions.TryGetDataKind(property.PropertyType.IsArray ? property.PropertyType.GetElementType() : property.PropertyType, out dk))
1481-
throw Contracts.Except($"Property {memberInfo.Name} is of unsupported type.");
1482-
break;
1482+
case PropertyInfo property:
1483+
if (!InternalDataKindExtensions.TryGetDataKind(property.PropertyType.IsArray ? property.PropertyType.GetElementType() : property.PropertyType, out dk))
1484+
throw Contracts.Except($"Property {memberInfo.Name} is of unsupported type.");
1485+
break;
14831486

1484-
default:
1485-
Contracts.Assert(false);
1486-
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
1487+
default:
1488+
Contracts.Assert(false);
1489+
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
14871490
}
14881491

14891492
column.Type = dk;

0 commit comments

Comments
 (0)