diff --git a/src/Microsoft.ML.Core/Data/ITransformModel.cs b/src/Microsoft.ML.Core/Data/ITransformModel.cs index ccc73265ec..b5a20c8dfe 100644 --- a/src/Microsoft.ML.Core/Data/ITransformModel.cs +++ b/src/Microsoft.ML.Core/Data/ITransformModel.cs @@ -54,5 +54,10 @@ public interface ITransformModel /// The transform model as an . If not all transforms /// in the pipeline are then it returns null. IRowToRowMapper AsRowToRowMapper(IExceptionContext ectx); + + /// + /// Get the loader information from the model. + /// + IDataView GetLoader(); } } diff --git a/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs b/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs index de0a8208b0..4446b9e773 100644 --- a/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs @@ -128,8 +128,9 @@ public static Output MakeScoringTransform(IHostEnvironment env, ModelInput input IPredictor predictor; RoleMappedData data; - var emptyData = new EmptyDataView(host, input.PredictorModel.TransformModel.InputSchema); - input.PredictorModel.PrepareData(host, emptyData, out data, out predictor); + var loader = input.PredictorModel.TransformModel.GetLoader(); + var dataview = loader ?? new EmptyDataView(host, input.PredictorModel.TransformModel.InputSchema); + input.PredictorModel.PrepareData(host, dataview, out data, out predictor); IDataView scoredPipe; using (var ch = host.Start("Creating scoring pipeline")) @@ -147,7 +148,7 @@ public static Output MakeScoringTransform(IHostEnvironment env, ModelInput input return new Output { ScoredData = scoredPipe, - ScoringTransform = new TransformModel(host, scoredPipe, emptyData) + ScoringTransform = new TransformModel(host, scoredPipe, dataview) }; } } diff --git a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs index 9edc87df6d..e99ca0251f 100644 --- a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs @@ -57,7 +57,7 @@ public TransformModel(IHostEnvironment env, IDataView result, IDataView input) env.CheckValue(result, nameof(result)); env.CheckValue(input, nameof(input)); - var root = new EmptyDataView(env, input.Schema); + var root = input is IDataLoader ? input : new EmptyDataView(env, input.Schema); _schemaRoot = root.Schema; _chain = ApplyTransformUtils.ApplyAllTransformsToData(env, result, root, input); } @@ -171,7 +171,7 @@ public void Save(IHostEnvironment env, Stream stream) using (var rep = RepositoryWriter.CreateNew(stream, ch)) { ch.Trace("Saving root schema and transformations"); - TrainUtils.SaveDataPipe(env, rep, _chain, blankLoader: true); + TrainUtils.SaveDataPipe(env, rep, _chain, blankLoader: false); rep.Commit(); } ch.Done(); @@ -186,6 +186,20 @@ public IRowToRowMapper AsRowToRowMapper(IExceptionContext ectx) : null; } + public IDataView GetLoader() + { + // Find the root schema. + for (IDataView view = _chain; ;) + { + var xf = view as IDataTransform; + if (xf == null) + { + return view; + } + view = xf.Source; + } + } + private sealed class CompositeRowToRowMapper : IRowToRowMapper { private readonly IDataView _chain;