-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Transformer for Concat #896
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,18 +4,11 @@ | |
|
||
using Microsoft.ML.Core.Data; | ||
using Microsoft.ML.Data.StaticPipe.Runtime; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Data.IO; | ||
using Microsoft.ML.Runtime.Internal.Utilities; | ||
using Microsoft.ML.Runtime.Model; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
|
||
[assembly: LoadableClass(typeof(ConcatTransformer), null, typeof(SignatureLoadModel), | ||
"Concat Transformer Wrapper", ConcatTransformer.LoaderSignature)] | ||
|
||
namespace Microsoft.ML.Runtime.Data | ||
{ | ||
public sealed class ConcatEstimator : IEstimator<ITransformer> | ||
|
@@ -41,11 +34,7 @@ public ConcatEstimator(IHostEnvironment env, string name, params string[] source | |
public ITransformer Fit(IDataView input) | ||
{ | ||
_host.CheckValue(input, nameof(input)); | ||
|
||
var xf = new ConcatTransform(_host, input, _name, _source); | ||
var empty = new EmptyDataView(_host, input.Schema); | ||
var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_host, xf, empty, input); | ||
return new ConcatTransformer(_host, chunk); | ||
return new ConcatTransform(_host, _name, _source); | ||
} | ||
|
||
private bool HasCategoricals(SchemaShape.Column col) | ||
|
@@ -123,90 +112,6 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) | |
} | ||
} | ||
|
||
// REVIEW: Note that the presence of this thing is a temporary measure only. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
People spend time writing this code, and you deleting it without any remorse! #ByDesign There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
// If it is cleaned up by code complete so much the better, but if not we will | ||
// have to wait a little bit. | ||
internal sealed class ConcatTransformer : ITransformer, ICanSaveModel | ||
{ | ||
public const string LoaderSignature = "ConcatTransformWrapper"; | ||
private const string TransformDirTemplate = "Step_{0:000}"; | ||
|
||
private readonly IHostEnvironment _env; | ||
private readonly IDataView _xf; | ||
|
||
internal ConcatTransformer(IHostEnvironment env, IDataView xf) | ||
{ | ||
_env = env; | ||
_xf = xf; | ||
} | ||
|
||
public ISchema GetOutputSchema(ISchema inputSchema) | ||
{ | ||
var dv = new EmptyDataView(_env, inputSchema); | ||
var output = ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, dv); | ||
return output.Schema; | ||
} | ||
|
||
public void Save(ModelSaveContext ctx) | ||
{ | ||
ctx.CheckAtModel(); | ||
ctx.SetVersionInfo(GetVersionInfo()); | ||
|
||
var dataPipe = _xf; | ||
var transforms = new List<IDataTransform>(); | ||
while (dataPipe is IDataTransform xf) | ||
{ | ||
// REVIEW: a malicious user could construct a loop in the Source chain, that would | ||
// cause this method to iterate forever (and throw something when the list overflows). There's | ||
// no way to insulate from ALL malicious behavior. | ||
transforms.Add(xf); | ||
dataPipe = xf.Source; | ||
Contracts.AssertValue(dataPipe); | ||
} | ||
transforms.Reverse(); | ||
|
||
ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_env, c, dataPipe.Schema)); | ||
|
||
ctx.Writer.Write(transforms.Count); | ||
for (int i = 0; i < transforms.Count; i++) | ||
{ | ||
var dirName = string.Format(TransformDirTemplate, i); | ||
ctx.SaveModel(transforms[i], dirName); | ||
} | ||
} | ||
|
||
private static VersionInfo GetVersionInfo() | ||
{ | ||
return new VersionInfo( | ||
modelSignature: "CCATWRPR", | ||
verWrittenCur: 0x00010001, // Initial | ||
verReadableCur: 0x00010001, | ||
verWeCanReadBack: 0x00010001, | ||
loaderSignature: LoaderSignature); | ||
} | ||
|
||
public ConcatTransformer(IHostEnvironment env, ModelLoadContext ctx) | ||
{ | ||
ctx.CheckAtModel(GetVersionInfo()); | ||
int n = ctx.Reader.ReadInt32(); | ||
|
||
ctx.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null)); | ||
|
||
IDataView data = loader; | ||
for (int i = 0; i < n; i++) | ||
{ | ||
var dirName = string.Format(TransformDirTemplate, i); | ||
ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out var xf, dirName, data); | ||
data = xf; | ||
} | ||
|
||
_env = env; | ||
_xf = data; | ||
} | ||
|
||
public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, input); | ||
} | ||
|
||
/// <summary> | ||
/// The extension methods and implementation support for concatenating columns together. | ||
/// </summary> | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would this stop double mapper creation which we have right now for each RowToRowMapperTransform instantiation? #Closed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think yes
In reply to: 217230609 [](ancestors = 217230609)