Skip to content

ITransformer derives from ICanSaveModel and explicit implementation for ICanSaveModel #2431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 9, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Core/Data/IEstimator.cs
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@
using System.Linq;
using Microsoft.Data.DataView;
using Microsoft.ML.Data;
using Microsoft.ML.Model;

namespace Microsoft.ML.Core.Data
{
@@ -263,7 +264,7 @@ public interface IDataReaderEstimator<in TSource, out TReader>
/// The transformer is a component that transforms data.
/// It also supports 'schema propagation' to answer the question of 'how will the data with this schema look, after you transform it?'.
/// </summary>
public interface ITransformer
public interface ITransformer : ICanSaveModel
{
/// <summary>
/// Schema propagation for transformers.
Original file line number Diff line number Diff line change
@@ -10,7 +10,8 @@

namespace Microsoft.ML.Model
{
[StructLayout(LayoutKind.Explicit, Size = ModelHeader.Size)]
[BestFriend]
[StructLayout(LayoutKind.Explicit, Size = Size)]
internal struct ModelHeader
{
/// <summary>
Original file line number Diff line number Diff line change
@@ -10,6 +10,12 @@

namespace Microsoft.ML.Model
{
/// <summary>
/// Signature for a repository based model loader. This is the dual of <see cref="ICanSaveModel"/>.
/// </summary>
[BestFriend]
internal delegate void SignatureLoadModel(ModelLoadContext ctx);

public sealed partial class ModelLoadContext : IDisposable
{
public const string ModelStreamName = "Model.key";
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -10,13 +10,11 @@

namespace Microsoft.ML.Model
{
/// <summary>
/// Signature for a repository based model loader. This is the dual of ICanSaveModel.
/// </summary>
public delegate void SignatureLoadModel(ModelLoadContext ctx);

/// <summary>
/// For saving a model into a repository.
/// Classes implementing <see cref="ICanSaveModel"/> should implement <see cref="Save(ModelSaveContext)"/> explicitly.
/// Classes inheriting <see cref="ICanSaveModel"/> from a base class should override the function invoked by <see cref="Save(ModelSaveContext)"/>
/// in that base class, if there is one.
/// </summary>
public interface ICanSaveModel
{
@@ -293,6 +291,8 @@ protected Entry AddEntry(string pathEnt, Stream stream)

public sealed class RepositoryWriter : Repository
{
private const string DirTrainingInfo = "TrainingInfo";

private ZipArchive _archive;
private Queue<KeyValuePair<string, Stream>> _closed;

@@ -301,7 +301,7 @@ public static RepositoryWriter CreateNew(Stream stream, IExceptionContext ectx =
Contracts.CheckValueOrNull(ectx);
ectx.CheckValue(stream, nameof(stream));
var rep = new RepositoryWriter(stream, ectx, useFileSystem);
using (var ent = rep.CreateEntry(ModelFileUtils.DirTrainingInfo, "Version.txt"))
using (var ent = rep.CreateEntry(DirTrainingInfo, "Version.txt"))
using (var writer = Utils.OpenWriter(ent.Stream))
writer.WriteLine(typeof(RepositoryWriter).Assembly.GetName().Version);
return rep;
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs
Original file line number Diff line number Diff line change
@@ -942,7 +942,7 @@ private static Stream OpenStream(string filename)
return OpenStream(files);
}

public void Save(ModelSaveContext ctx)
void ICanSaveModel.Save(ModelSaveContext ctx)
{
_host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
Original file line number Diff line number Diff line change
@@ -502,7 +502,7 @@ private static IDataLoader LoadTransforms(ModelLoadContext ctx, IDataLoader srcL
});
}

public void Save(ModelSaveContext ctx)
void ICanSaveModel.Save(ModelSaveContext ctx)
{
_host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
Original file line number Diff line number Diff line change
@@ -834,7 +834,7 @@ public Bindings(ModelLoadContext ctx, TextLoader parent)
OutputSchema = ComputeOutputSchema();
}

public void Save(ModelSaveContext ctx)
internal void Save(ModelSaveContext ctx)
{
Contracts.AssertValue(ctx);

@@ -1283,7 +1283,7 @@ internal static IDataLoader Create(IHostEnvironment env, Arguments args, IMultiS
internal static IDataView ReadFile(IHostEnvironment env, Arguments args, IMultiStreamSource fileSource)
=> new TextLoader(env, args, fileSource).Read(fileSource);

public void Save(ModelSaveContext ctx)
void ICanSaveModel.Save(ModelSaveContext ctx)
{
_host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
@@ -1420,7 +1420,7 @@ public RowCursor[] GetRowCursorSet(IEnumerable<Schema.Column> columnsNeeded, int
return Cursor.CreateSet(_reader, _files, active, n);
}

public void Save(ModelSaveContext ctx) => _reader.Save(ctx);
void ICanSaveModel.Save(ModelSaveContext ctx) => ((ICanSaveModel)_reader).Save(ctx);
}
}
}
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@ namespace Microsoft.ML.Data
{
// REVIEW: this class is public, as long as the Wrappers.cs in tests still rely on it.
// It needs to become internal.
public sealed class TransformWrapper : ITransformer, ICanSaveModel
public sealed class TransformWrapper : ITransformer
{
public const string LoaderSignature = "TransformWrapper";
private const string TransformDirTemplate = "Step_{0:000}";
@@ -46,7 +46,7 @@ public Schema GetOutputSchema(Schema inputSchema)
return output.Schema;
}

public void Save(ModelSaveContext ctx)
void ICanSaveModel.Save(ModelSaveContext ctx)
{
if (!_allowSave)
throw _host.Except("Saving is not permitted.");
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@ internal interface ITransformerChainAccessor
/// A chain of transformers (possibly empty) that end with a <typeparamref name="TLastTransformer"/>.
/// For an empty chain, <typeparamref name="TLastTransformer"/> is always <see cref="ITransformer"/>.
/// </summary>
public sealed class TransformerChain<TLastTransformer> : ITransformer, ICanSaveModel, IEnumerable<ITransformer>, ITransformerChainAccessor
public sealed class TransformerChain<TLastTransformer> : ITransformer, IEnumerable<ITransformer>, ITransformerChainAccessor
where TLastTransformer : class, ITransformer
{
private readonly ITransformer[] _transformers;
@@ -165,7 +165,7 @@ public TransformerChain<TNewLast> Append<TNewLast>(TNewLast transformer, Transfo
return new TransformerChain<TNewLast>(_transformers.AppendElement(transformer), _scopes.AppendElement(scope));
}

public void Save(ModelSaveContext ctx)
void ICanSaveModel.Save(ModelSaveContext ctx)
{
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
@@ -181,7 +181,7 @@ public void Save(ModelSaveContext ctx)
}

/// <summary>
/// The loading constructor of transformer chain. Reverse of <see cref="Save(ModelSaveContext)"/>.
/// The loading constructor of transformer chain. Reverse of <see cref="ICanSaveModel.Save"/>.
/// </summary>
internal TransformerChain(IHostEnvironment env, ModelLoadContext ctx)
{
Original file line number Diff line number Diff line change
@@ -513,7 +513,7 @@ public static TransposeLoader Create(IHostEnvironment env, ModelLoadContext ctx,
});
}

public void Save(ModelSaveContext ctx)
void ICanSaveModel.Save(ModelSaveContext ctx)
{
_host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/DataView/LambdaColumnMapper.cs
Original file line number Diff line number Diff line change
@@ -140,7 +140,7 @@ public Impl(IHostEnvironment env, string name, IDataView input, OneToOneColumn c
Metadata.Seal();
}

public override void Save(ModelSaveContext ctx)
private protected override void SaveModel(ModelSaveContext ctx)
{
Host.Assert(false, "Shouldn't serialize this!");
throw Host.ExceptNotSupp("Shouldn't serialize this");
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/DataView/LambdaFilter.cs
Original file line number Diff line number Diff line change
@@ -96,7 +96,7 @@ public Impl(IHostEnvironment env, string name, IDataView input,
_conv = conv;
}

public override void Save(ModelSaveContext ctx)
private protected override void SaveModel(ModelSaveContext ctx)
{
Host.Assert(false, "Shouldn't serialize this!");
throw Host.ExceptNotSupp("Shouldn't serialize this");
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
Original file line number Diff line number Diff line change
@@ -131,7 +131,7 @@ public static RowToRowMapperTransform Create(IHostEnvironment env, ModelLoadCont
return h.Apply("Loading Model", ch => new RowToRowMapperTransform(h, ctx, input));
}

public override void Save(ModelSaveContext ctx)
private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
Original file line number Diff line number Diff line change
@@ -223,7 +223,7 @@ public static ChooseColumnsByIndexTransform Create(IHostEnvironment env, ModelLo
return h.Apply("Loading Model", ch => new ChooseColumnsByIndexTransform(h, ctx, input));
}

public override void Save(ModelSaveContext ctx)
private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
Original file line number Diff line number Diff line change
@@ -948,7 +948,7 @@ public static BinaryPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadC
return new BinaryPerInstanceEvaluator(env, ctx, schema);
}

public override void Save(ModelSaveContext ctx)
private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
@@ -960,7 +960,7 @@ public override void Save(ModelSaveContext ctx)
// float: _threshold
// byte: _useRaw

base.Save(ctx);
base.SaveModel(ctx);
ctx.SaveStringOrNull(_probCol);
Contracts.Assert(FloatUtils.IsFinite(_threshold));
ctx.Writer.Write(_threshold);
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs
Original file line number Diff line number Diff line change
@@ -628,13 +628,13 @@ public static ClusteringPerInstanceEvaluator Create(IHostEnvironment env, ModelL
return new ClusteringPerInstanceEvaluator(env, ctx, schema);
}

public override void Save(ModelSaveContext ctx)
private protected override void SaveModel(ModelSaveContext ctx)
{
// *** Binary format **
// base
// int: number of clusters

base.Save(ctx);
base.SaveModel(ctx);
Host.Assert(_numClusters > 0);
ctx.Writer.Write(_numClusters);
}
7 changes: 6 additions & 1 deletion src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs
Original file line number Diff line number Diff line change
@@ -510,7 +510,12 @@ protected PerInstanceEvaluatorBase(IHostEnvironment env, ModelLoadContext ctx,
throw Host.ExceptSchemaMismatch(nameof(schema), "score", ScoreCol);
}

public virtual void Save(ModelSaveContext ctx)
void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);

/// <summary>
/// A derived class, for example A, should override <see cref="SaveModel"/> so that ((<see cref="ICanSaveModel"/>)A).Save(ctx) can correctly save A.
/// </summary>
private protected virtual void SaveModel(ModelSaveContext ctx)
{
// *** Binary format **
// int: Id of the score column name
Original file line number Diff line number Diff line change
@@ -631,7 +631,7 @@ public static MultiClassPerInstanceEvaluator Create(IHostEnvironment env, ModelL
return new MultiClassPerInstanceEvaluator(env, ctx, schema);
}

public override void Save(ModelSaveContext ctx)
private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
@@ -642,7 +642,7 @@ public override void Save(ModelSaveContext ctx)
// int: number of classes
// int[]: Ids of the class names

base.Save(ctx);
base.SaveModel(ctx);
Host.Assert(_numClasses > 0);
ctx.Writer.Write(_numClasses);
for (int i = 0; i < _numClasses; i++)
Original file line number Diff line number Diff line change
@@ -426,15 +426,15 @@ public static MultiOutputRegressionPerInstanceEvaluator Create(IHostEnvironment
return new MultiOutputRegressionPerInstanceEvaluator(env, ctx, schema);
}

public override void Save(ModelSaveContext ctx)
private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());

// *** Binary format **
// base
base.Save(ctx);
base.SaveModel(ctx);
}

private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
Original file line number Diff line number Diff line change
@@ -324,7 +324,7 @@ public static QuantileRegressionPerInstanceEvaluator Create(IHostEnvironment env
return new QuantileRegressionPerInstanceEvaluator(env, ctx, schema);
}

public override void Save(ModelSaveContext ctx)
private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
@@ -335,7 +335,7 @@ public override void Save(ModelSaveContext ctx)
// int: _scoreSize
// int[]: Ids of the quantile names

base.Save(ctx);
base.SaveModel(ctx);
Host.Assert(_scoreSize > 0);
ctx.Writer.Write(_scoreSize);
var quantiles = _quantiles.GetValues();
8 changes: 4 additions & 4 deletions src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
Original file line number Diff line number Diff line change
@@ -597,11 +597,11 @@ public static RankerPerInstanceTransform Create(IHostEnvironment env, ModelLoadC
return h.Apply("Loading Model", ch => new RankerPerInstanceTransform(h, ctx, input));
}

public void Save(ModelSaveContext ctx)
void ICanSaveModel.Save(ModelSaveContext ctx)
{
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
_transform.Save(ctx);
((ICanSaveModel)_transform).Save(ctx);
}

public long? GetRowCount()
@@ -715,7 +715,7 @@ public Transform(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
_bindings = new Bindings(Host, input.Schema, false, LabelCol, ScoreCol, GroupCol, _truncationLevel);
}

public override void Save(ModelSaveContext ctx)
private protected override void SaveModel(ModelSaveContext ctx)
{
Host.AssertValue(ctx);

@@ -725,7 +725,7 @@ public override void Save(ModelSaveContext ctx)
// int: _labelGains.Length
// double[]: _labelGains

base.Save(ctx);
base.SaveModel(ctx);
Host.Assert(0 < _truncationLevel && _truncationLevel < 100);
ctx.Writer.Write(_truncationLevel);
ctx.Writer.WriteDoubleArray(_labelGains);
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs
Original file line number Diff line number Diff line change
@@ -234,15 +234,15 @@ public static RegressionPerInstanceEvaluator Create(IHostEnvironment env, ModelL
return new RegressionPerInstanceEvaluator(env, ctx, schema);
}

public override void Save(ModelSaveContext ctx)
private protected override void SaveModel(ModelSaveContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());

// *** Binary format **
// base
base.Save(ctx);
base.SaveModel(ctx);
}

private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
Loading