Skip to content

Transformer for Concat #896

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 14, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,15 @@ public sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRowMa
{
private sealed class Bindings : ColumnBindingsBase
{
private readonly RowToRowMapperTransform _parent;
private readonly IRowMapper _mapper;
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Sep 13, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IRowMapper [](start = 29, length = 10)

Would this stop double mapper creation which we have right now for each RowToRowMapperTransform instantiation? #Closed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think yes


In reply to: 217230609 [](ancestors = 217230609)

public readonly RowMapperColumnInfo[] OutputColInfos;

public Bindings(ISchema inputSchema, RowToRowMapperTransform parent)
: base(inputSchema, true, Contracts.CheckRef(parent, nameof(parent))._mapper.GetOutputColumns().Select(info => info.Name).ToArray())
public Bindings(ISchema inputSchema, IRowMapper mapper)
: base(inputSchema, true, Contracts.CheckRef(mapper, nameof(mapper)).GetOutputColumns().Select(info => info.Name).ToArray())
{
Contracts.AssertValue(parent);
_parent = parent;
OutputColInfos = _parent._mapper.GetOutputColumns().ToArray();
Contracts.AssertValue(mapper);
_mapper = mapper;
OutputColInfos = _mapper.GetOutputColumns().ToArray();
}

protected override ColumnType GetColumnTypeCore(int iinfo)
Expand All @@ -168,7 +168,7 @@ public bool[] GetActive(Func<int, bool> predicate, out Func<int, bool> predicate
var predicateOut = GetActiveOutputColumns(active);

// Now map those to active input columns.
var predicateIn = _parent._mapper.GetDependencies(predicateOut);
var predicateIn = _mapper.GetDependencies(predicateOut);

// Combine the two sets of input columns.
predicateInput =
Expand Down Expand Up @@ -255,7 +255,14 @@ public RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper
{
Contracts.CheckValue(mapper, nameof(mapper));
_mapper = mapper;
_bindings = new Bindings(input.Schema, this);
_bindings = new Bindings(input.Schema, mapper);
}

public static ISchema GetOutputSchema(ISchema inputSchema, IRowMapper mapper)
{
Contracts.CheckValue(inputSchema, nameof(inputSchema));
Contracts.CheckValue(mapper, nameof(mapper));
return new Bindings(inputSchema, mapper);
}

private RowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView input)
Expand All @@ -265,7 +272,7 @@ private RowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView inpu
// _mapper

ctx.LoadModel<IRowMapper, SignatureLoadRowMapper>(host, out _mapper, "Mapper", input.Schema);
_bindings = new Bindings(input.Schema, this);
_bindings = new Bindings(input.Schema, _mapper);
}

public static RowToRowMapperTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public static CommonOutputs.TransformOutput ConcatColumns(IHostEnvironment env,
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);

var xf = new ConcatTransform(env, input, input.Data);
var xf = ConcatTransform.Create(env, input, input.Data);
return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf };
}

Expand Down
97 changes: 1 addition & 96 deletions src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,11 @@

using Microsoft.ML.Core.Data;
using Microsoft.ML.Data.StaticPipe.Runtime;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.IO;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
using System;
using System.Collections.Generic;
using System.Linq;

[assembly: LoadableClass(typeof(ConcatTransformer), null, typeof(SignatureLoadModel),
"Concat Transformer Wrapper", ConcatTransformer.LoaderSignature)]

namespace Microsoft.ML.Runtime.Data
{
public sealed class ConcatEstimator : IEstimator<ITransformer>
Expand All @@ -41,11 +34,7 @@ public ConcatEstimator(IHostEnvironment env, string name, params string[] source
public ITransformer Fit(IDataView input)
{
_host.CheckValue(input, nameof(input));

var xf = new ConcatTransform(_host, input, _name, _source);
var empty = new EmptyDataView(_host, input.Schema);
var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_host, xf, empty, input);
return new ConcatTransformer(_host, chunk);
return new ConcatTransform(_host, _name, _source);
}

private bool HasCategoricals(SchemaShape.Column col)
Expand Down Expand Up @@ -123,90 +112,6 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
}
}

// REVIEW: Note that the presence of this thing is a temporary measure only.
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Sep 13, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// REVIEW: Note that the presence of this thing is a temporary measure only. [](start = 4, length = 76)

People spend time writing this code, and you deleting it without any remorse! #ByDesign

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extra credit for removing Tom's code!


In reply to: 217230777 [](ancestors = 217230777)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a code falls in the forest...


In reply to: 217551182 [](ancestors = 217551182,217230777)

// If it is cleaned up by code complete so much the better, but if not we will
// have to wait a little bit.
internal sealed class ConcatTransformer : ITransformer, ICanSaveModel
{
public const string LoaderSignature = "ConcatTransformWrapper";
private const string TransformDirTemplate = "Step_{0:000}";

private readonly IHostEnvironment _env;
private readonly IDataView _xf;

internal ConcatTransformer(IHostEnvironment env, IDataView xf)
{
_env = env;
_xf = xf;
}

public ISchema GetOutputSchema(ISchema inputSchema)
{
var dv = new EmptyDataView(_env, inputSchema);
var output = ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, dv);
return output.Schema;
}

public void Save(ModelSaveContext ctx)
{
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());

var dataPipe = _xf;
var transforms = new List<IDataTransform>();
while (dataPipe is IDataTransform xf)
{
// REVIEW: a malicious user could construct a loop in the Source chain, that would
// cause this method to iterate forever (and throw something when the list overflows). There's
// no way to insulate from ALL malicious behavior.
transforms.Add(xf);
dataPipe = xf.Source;
Contracts.AssertValue(dataPipe);
}
transforms.Reverse();

ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_env, c, dataPipe.Schema));

ctx.Writer.Write(transforms.Count);
for (int i = 0; i < transforms.Count; i++)
{
var dirName = string.Format(TransformDirTemplate, i);
ctx.SaveModel(transforms[i], dirName);
}
}

private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "CCATWRPR",
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature);
}

public ConcatTransformer(IHostEnvironment env, ModelLoadContext ctx)
{
ctx.CheckAtModel(GetVersionInfo());
int n = ctx.Reader.ReadInt32();

ctx.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));

IDataView data = loader;
for (int i = 0; i < n; i++)
{
var dirName = string.Format(TransformDirTemplate, i);
ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out var xf, dirName, data);
data = xf;
}

_env = env;
_xf = data;
}

public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, input);
}

/// <summary>
/// The extension methods and implementation support for concatenating columns together.
/// </summary>
Expand Down
Loading