Skip to content

Commit 3a7d3da

Browse files
committed
post rebase
1 parent 6f49246 commit 3a7d3da

File tree

1 file changed

+67
-64
lines changed

1 file changed

+67
-64
lines changed

src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs

Lines changed: 67 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ namespace Microsoft.ML.Data
2727
/// </summary>
2828
[BestFriend]
2929
internal interface IRowMapper : ICanSaveModel
30-
{
30+
{
3131
/// <summary>
3232
/// Returns the input columns needed for the requested output columns.
3333
/// </summary>
@@ -54,7 +54,10 @@ internal interface IRowMapper : ICanSaveModel
5454
/// Returns the parent transformer which uses this mapper.
5555
/// </summary>
5656
ITransformer GetTransformer();
57-
}
57+
}
58+
59+
[BestFriend]
60+
internal delegate void SignatureLoadRowMapper(ModelLoadContext ctx, Schema schema);
5861

5962
/// <summary>
6063
/// This class is a transform that can add any number of output columns, that depend on any number of input columns.
@@ -64,7 +67,7 @@ internal interface IRowMapper : ICanSaveModel
6467
[BestFriend]
6568
internal sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRowMapper,
6669
ITransformCanSaveOnnx, ITransformCanSavePfa, ITransformTemplate
67-
{
70+
{
6871
private readonly IRowMapper _mapper;
6972
private readonly ColumnBindings _bindings;
7073

@@ -74,15 +77,15 @@ internal sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRow
7477
public const string RegistrationName = "RowToRowMapperTransform";
7578
public const string LoaderSignature = "RowToRowMapper";
7679
private static VersionInfo GetVersionInfo()
77-
{
80+
{
7881
return new VersionInfo(
7982
modelSignature: "ROW MPPR",
8083
verWrittenCur: 0x00010001, // Initial
8184
verReadableCur: 0x00010001,
8285
verWeCanReadBack: 0x00010001,
8386
loaderSignature: LoaderSignature,
8487
loaderAssemblyName: typeof(RowToRowMapperTransform).Assembly.FullName);
85-
}
88+
}
8689

8790
public override Schema OutputSchema => _bindings.Schema;
8891

@@ -92,44 +95,44 @@ private static VersionInfo GetVersionInfo()
9295

9396
public RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper mapper, Func<Schema, IRowMapper> mapperFactory)
9497
: base(env, RegistrationName, input)
95-
{
98+
{
9699
Contracts.CheckValue(mapper, nameof(mapper));
97100
Contracts.CheckValueOrNull(mapperFactory);
98101
_mapper = mapper;
99102
_mapperFactory = mapperFactory;
100103
_bindings = new ColumnBindings(input.Schema, mapper.GetOutputColumns());
101-
}
104+
}
102105

103106
[BestFriend]
104107
internal static Schema GetOutputSchema(Schema inputSchema, IRowMapper mapper)
105-
{
108+
{
106109
Contracts.CheckValue(inputSchema, nameof(inputSchema));
107110
Contracts.CheckValue(mapper, nameof(mapper));
108111
return new ColumnBindings(inputSchema, mapper.GetOutputColumns()).Schema;
109-
}
112+
}
110113

111114
private RowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView input)
112115
: base(host, input)
113-
{
116+
{
114117
// *** Binary format ***
115118
// _mapper
116119

117120
ctx.LoadModel<IRowMapper, SignatureLoadRowMapper>(host, out _mapper, "Mapper", input.Schema);
118121
_bindings = new ColumnBindings(input.Schema, _mapper.GetOutputColumns());
119-
}
122+
}
120123

121124
public static RowToRowMapperTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
122-
{
125+
{
123126
Contracts.CheckValue(env, nameof(env));
124127
var h = env.Register(RegistrationName);
125128
h.CheckValue(ctx, nameof(ctx));
126129
ctx.CheckAtModel(GetVersionInfo());
127130
h.CheckValue(input, nameof(input));
128131
return h.Apply("Loading Model", ch => new RowToRowMapperTransform(h, ctx, input));
129-
}
132+
}
130133

131134
private protected override void SaveModel(ModelSaveContext ctx)
132-
{
135+
{
133136
Host.CheckValue(ctx, nameof(ctx));
134137
ctx.CheckAtModel();
135138
ctx.SetVersionInfo(GetVersionInfo());
@@ -138,14 +141,14 @@ private protected override void SaveModel(ModelSaveContext ctx)
138141
// _mapper
139142

140143
ctx.SaveModel(_mapper, "Mapper");
141-
}
144+
}
142145

143146
/// <summary>
144147
/// Produces the set of active columns for the data view (as a bool[] of length _bindings.Schema.Count),
145148
/// and the needed active input columns, given a predicate for the needed active output columns.
146149
/// </summary>
147150
private bool[] GetActive(Func<int, bool> predicate, out IEnumerable<Schema.Column> inputColumns)
148-
{
151+
{
149152
int n = _bindings.Schema.Count;
150153
var active = Utils.BuildArray(n, predicate);
151154
Contracts.Assert(active.Length == n);
@@ -163,10 +166,10 @@ private bool[] GetActive(Func<int, bool> predicate, out IEnumerable<Schema.Colum
163166
inputColumns = _bindings.InputSchema.Where(col => activeInput[col.Index] || predicateIn(col.Index));
164167

165168
return active;
166-
}
169+
}
167170

168171
private Func<int, bool> GetActiveOutputColumns(bool[] active)
169-
{
172+
{
170173
Contracts.AssertValue(active);
171174
Contracts.Assert(active.Length == _bindings.Schema.Count);
172175

@@ -176,26 +179,26 @@ private Func<int, bool> GetActiveOutputColumns(bool[] active)
176179
Contracts.Assert(0 <= col && col < _bindings.AddedColumnIndices.Count);
177180
return 0 <= col && col < _bindings.AddedColumnIndices.Count && active[_bindings.AddedColumnIndices[col]];
178181
};
179-
}
182+
}
180183

181184
protected override bool? ShouldUseParallelCursors(Func<int, bool> predicate)
182-
{
185+
{
183186
Host.AssertValue(predicate, "predicate");
184187
if (_bindings.AddedColumnIndices.Any(predicate))
185188
return true;
186189
return null;
187-
}
190+
}
188191

189192
protected override RowCursor GetRowCursorCore(IEnumerable<Schema.Column> columnsNeeded, Random rand = null)
190-
{
193+
{
191194
var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
192195
var active = GetActive(predicate, out IEnumerable<Schema.Column> inputCols);
193196

194197
return new Cursor(Host, Source.GetRowCursor(inputCols, rand), this, active);
195-
}
198+
}
196199

197200
public override RowCursor[] GetRowCursorSet(IEnumerable<Schema.Column> columnsNeeded, int n, Random rand = null)
198-
{
201+
{
199202
Host.CheckValueOrNull(rand);
200203

201204
var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
@@ -212,89 +215,89 @@ public override RowCursor[] GetRowCursorSet(IEnumerable<Schema.Column> columnsNe
212215
for (int i = 0; i < inputs.Length; i++)
213216
cursors[i] = new Cursor(Host, inputs[i], this, active);
214217
return cursors;
215-
}
218+
}
216219

217220
void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx)
218-
{
221+
{
219222
Host.CheckValue(ctx, nameof(ctx));
220223
if (_mapper is ISaveAsOnnx onnx)
221-
{
224+
{
222225
Host.Check(onnx.CanSaveOnnx(ctx), "Cannot be saved as ONNX.");
223226
onnx.SaveAsOnnx(ctx);
224-
}
225227
}
228+
}
226229

227230
void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx)
228-
{
231+
{
229232
Host.CheckValue(ctx, nameof(ctx));
230233
if (_mapper is ISaveAsPfa pfa)
231-
{
234+
{
232235
Host.Check(pfa.CanSavePfa, "Cannot be saved as PFA.");
233236
pfa.SaveAsPfa(ctx);
234-
}
235237
}
238+
}
236239

237240
/// <summary>
238241
/// Given a set of output columns, return the input columns that are needed to generate those output columns.
239242
/// </summary>
240243
IEnumerable<Schema.Column> IRowToRowMapper.GetDependencies(IEnumerable<Schema.Column> dependingColumns)
241-
{
244+
{
242245
var predicate = RowCursorUtils.FromColumnsToPredicate(dependingColumns, OutputSchema);
243246
GetActive(predicate, out var inputColumns);
244247
return inputColumns;
245-
}
248+
}
246249

247250
public Schema InputSchema => Source.Schema;
248251

249252
public Row GetRow(Row input, Func<int, bool> active)
250-
{
253+
{
251254
Host.CheckValue(input, nameof(input));
252255
Host.CheckValue(active, nameof(active));
253256
Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to");
254257

255258
using (var ch = Host.Start("GetEntireRow"))
256-
{
259+
{
257260
var activeArr = new bool[OutputSchema.Count];
258261
for (int i = 0; i < OutputSchema.Count; i++)
259262
activeArr[i] = active(i);
260263
var pred = GetActiveOutputColumns(activeArr);
261264
var getters = _mapper.CreateGetters(input, pred, out Action disp);
262265
return new RowImpl(input, this, OutputSchema, getters, disp);
263-
}
264266
}
267+
}
265268

266269
IDataTransform ITransformTemplate.ApplyToData(IHostEnvironment env, IDataView newSource)
267-
{
270+
{
268271
Contracts.CheckValue(env, nameof(env));
269272

270273
Contracts.CheckValue(newSource, nameof(newSource));
271274
if (_mapperFactory != null)
272-
{
275+
{
273276
var newMapper = _mapperFactory(newSource.Schema);
274277
return new RowToRowMapperTransform(env.Register(nameof(RowToRowMapperTransform)), newSource, newMapper, _mapperFactory);
275-
}
278+
}
276279
// Fall back to serialization. This used to be how it worked in all cases; now it is used only when we can't re-create the mapper.
277280
using (var stream = new MemoryStream())
278-
{
281+
{
279282
using (var rep = RepositoryWriter.CreateNew(stream, env))
280-
{
283+
{
281284
ModelSaveContext.SaveModel(rep, this, "model");
282285
rep.Commit();
283-
}
286+
}
284287

285288
stream.Position = 0;
286289
using (var rep = RepositoryReader.Open(stream, env))
287-
{
290+
{
288291
IDataTransform newData;
289292
ModelLoadContext.LoadModel<IDataTransform, SignatureLoadDataTransform>(env,
290293
out newData, rep, "model", newSource);
291294
return newData;
292-
}
293295
}
294296
}
297+
}
295298

296299
private sealed class RowImpl : WrappingRow
297-
{
300+
{
298301
private readonly Delegate[] _getters;
299302
private readonly RowToRowMapperTransform _parent;
300303
private readonly Action _disposer;
@@ -303,21 +306,21 @@ private sealed class RowImpl : WrappingRow
303306

304307
public RowImpl(Row input, RowToRowMapperTransform parent, Schema schema, Delegate[] getters, Action disposer)
305308
: base(input)
306-
{
309+
{
307310
_parent = parent;
308311
Schema = schema;
309312
_getters = getters;
310313
_disposer = disposer;
311-
}
314+
}
312315

313316
protected override void DisposeCore(bool disposing)
314-
{
317+
{
315318
if (disposing)
316319
_disposer?.Invoke();
317-
}
320+
}
318321

319322
public override ValueGetter<TValue> GetGetter<TValue>(int col)
320-
{
323+
{
321324
bool isSrc;
322325
int index = _parent._bindings.MapColumnIndex(out isSrc, col);
323326
if (isSrc)
@@ -328,20 +331,20 @@ public override ValueGetter<TValue> GetGetter<TValue>(int col)
328331
if (fn == null)
329332
throw Contracts.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue));
330333
return fn;
331-
}
334+
}
332335

333336
public override bool IsColumnActive(int col)
334-
{
337+
{
335338
bool isSrc;
336339
int index = _parent._bindings.MapColumnIndex(out isSrc, col);
337340
if (isSrc)
338341
return Input.IsColumnActive((index));
339342
return _getters[index] != null;
340-
}
341343
}
344+
}
342345

343346
private sealed class Cursor : SynchronizedCursorBase
344-
{
347+
{
345348
private readonly Delegate[] _getters;
346349
private readonly bool[] _active;
347350
private readonly ColumnBindings _bindings;
@@ -352,21 +355,21 @@ private sealed class Cursor : SynchronizedCursorBase
352355

353356
public Cursor(IChannelProvider provider, RowCursor input, RowToRowMapperTransform parent, bool[] active)
354357
: base(provider, input)
355-
{
358+
{
356359
var pred = parent.GetActiveOutputColumns(active);
357360
_getters = parent._mapper.CreateGetters(input, pred, out _disposer);
358361
_active = active;
359362
_bindings = parent._bindings;
360-
}
363+
}
361364

362365
public override bool IsColumnActive(int col)
363-
{
366+
{
364367
Ch.Check(0 <= col && col < _bindings.Schema.Count);
365368
return _active[col];
366-
}
369+
}
367370

368371
public override ValueGetter<TValue> GetGetter<TValue>(int col)
369-
{
372+
{
370373
Ch.Check(IsColumnActive(col));
371374

372375
bool isSrc;
@@ -381,22 +384,22 @@ public override ValueGetter<TValue> GetGetter<TValue>(int col)
381384
if (fn == null)
382385
throw Ch.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue));
383386
return fn;
384-
}
387+
}
385388

386389
protected override void Dispose(bool disposing)
387-
{
390+
{
388391
if (_disposed)
389392
return;
390393
if (disposing)
391394
_disposer?.Invoke();
392395
_disposed = true;
393396
base.Dispose(disposing);
394-
}
395397
}
398+
}
396399

397400
internal ITransformer GetTransformer()
398-
{
401+
{
399402
return _mapper.GetTransformer();
400-
}
401403
}
402404
}
405+
}

0 commit comments

Comments
 (0)