diff --git a/src/Microsoft.ML.Core/Data/ITransformModel.cs b/src/Microsoft.ML.Core/Data/ITransformModel.cs
index ccc73265ec..b5a20c8dfe 100644
--- a/src/Microsoft.ML.Core/Data/ITransformModel.cs
+++ b/src/Microsoft.ML.Core/Data/ITransformModel.cs
@@ -54,5 +54,10 @@ public interface ITransformModel
/// The transform model as an . If not all transforms
/// in the pipeline are then it returns null.
IRowToRowMapper AsRowToRowMapper(IExceptionContext ectx);
+
+ ///
+ /// Get the loader information from the model.
+ ///
+ IDataView GetLoader();
}
}
diff --git a/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs b/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs
index de0a8208b0..4446b9e773 100644
--- a/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs
@@ -128,8 +128,9 @@ public static Output MakeScoringTransform(IHostEnvironment env, ModelInput input
IPredictor predictor;
RoleMappedData data;
- var emptyData = new EmptyDataView(host, input.PredictorModel.TransformModel.InputSchema);
- input.PredictorModel.PrepareData(host, emptyData, out data, out predictor);
+ var loader = input.PredictorModel.TransformModel.GetLoader();
+ var dataview = loader ?? new EmptyDataView(host, input.PredictorModel.TransformModel.InputSchema);
+ input.PredictorModel.PrepareData(host, dataview, out data, out predictor);
IDataView scoredPipe;
using (var ch = host.Start("Creating scoring pipeline"))
@@ -147,7 +148,7 @@ public static Output MakeScoringTransform(IHostEnvironment env, ModelInput input
return new Output
{
ScoredData = scoredPipe,
- ScoringTransform = new TransformModel(host, scoredPipe, emptyData)
+ ScoringTransform = new TransformModel(host, scoredPipe, dataview)
};
}
}
diff --git a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs
index 9edc87df6d..e99ca0251f 100644
--- a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs
@@ -57,7 +57,7 @@ public TransformModel(IHostEnvironment env, IDataView result, IDataView input)
env.CheckValue(result, nameof(result));
env.CheckValue(input, nameof(input));
- var root = new EmptyDataView(env, input.Schema);
+ var root = input is IDataLoader ? input : new EmptyDataView(env, input.Schema);
_schemaRoot = root.Schema;
_chain = ApplyTransformUtils.ApplyAllTransformsToData(env, result, root, input);
}
@@ -171,7 +171,7 @@ public void Save(IHostEnvironment env, Stream stream)
using (var rep = RepositoryWriter.CreateNew(stream, ch))
{
ch.Trace("Saving root schema and transformations");
- TrainUtils.SaveDataPipe(env, rep, _chain, blankLoader: true);
+ TrainUtils.SaveDataPipe(env, rep, _chain, blankLoader: false);
rep.Commit();
}
ch.Done();
@@ -186,6 +186,20 @@ public IRowToRowMapper AsRowToRowMapper(IExceptionContext ectx)
: null;
}
+ public IDataView GetLoader()
+ {
+ // Find the root schema.
+ for (IDataView view = _chain; ;)
+ {
+ var xf = view as IDataTransform;
+ if (xf == null)
+ {
+ return view;
+ }
+ view = xf.Source;
+ }
+ }
+
private sealed class CompositeRowToRowMapper : IRowToRowMapper
{
private readonly IDataView _chain;