Skip to content

Commit 6f49246

Browse files
committed
addressing comments from round 6
1 parent 1e4660e commit 6f49246

File tree

5 files changed

+71
-74
lines changed

5 files changed

+71
-74
lines changed

src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public Row GetRow(Row input, Func<int, bool> active)
7878
{
7979
var outputColumns = InnerMappers[i].OutputSchema.Where(c => deps[i](c.Index));
8080
var cols = InnerMappers[i].GetDependencies(outputColumns).ToArray();
81-
deps[i - 1] = c => cols.Count() > 0 ? cols.Any(col => col.Index == c) : false;
81+
deps[i - 1] = c => cols.Length > 0 ? cols.Any(col => col.Index == c) : false;
8282
}
8383

8484
Row result = input;

src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs

Lines changed: 68 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ namespace Microsoft.ML.Data
2727
/// </summary>
2828
[BestFriend]
2929
internal interface IRowMapper : ICanSaveModel
30-
{
30+
{
3131
/// <summary>
3232
/// Returns the input columns needed for the requested output columns.
3333
/// </summary>
@@ -54,9 +54,7 @@ internal interface IRowMapper : ICanSaveModel
5454
/// Returns the parent transformer which uses this mapper.
5555
/// </summary>
5656
ITransformer GetTransformer();
57-
}
58-
[BestFriend]
59-
internal delegate void SignatureLoadRowMapper(ModelLoadContext ctx, Schema schema);
57+
}
6058

6159
/// <summary>
6260
/// This class is a transform that can add any number of output columns, that depend on any number of input columns.
@@ -66,7 +64,7 @@ internal interface IRowMapper : ICanSaveModel
6664
[BestFriend]
6765
internal sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRowMapper,
6866
ITransformCanSaveOnnx, ITransformCanSavePfa, ITransformTemplate
69-
{
67+
{
7068
private readonly IRowMapper _mapper;
7169
private readonly ColumnBindings _bindings;
7270

@@ -76,15 +74,15 @@ internal sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRow
7674
public const string RegistrationName = "RowToRowMapperTransform";
7775
public const string LoaderSignature = "RowToRowMapper";
7876
private static VersionInfo GetVersionInfo()
79-
{
77+
{
8078
return new VersionInfo(
8179
modelSignature: "ROW MPPR",
8280
verWrittenCur: 0x00010001, // Initial
8381
verReadableCur: 0x00010001,
8482
verWeCanReadBack: 0x00010001,
8583
loaderSignature: LoaderSignature,
8684
loaderAssemblyName: typeof(RowToRowMapperTransform).Assembly.FullName);
87-
}
85+
}
8886

8987
public override Schema OutputSchema => _bindings.Schema;
9088

@@ -94,43 +92,44 @@ private static VersionInfo GetVersionInfo()
9492

9593
public RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper mapper, Func<Schema, IRowMapper> mapperFactory)
9694
: base(env, RegistrationName, input)
97-
{
95+
{
9896
Contracts.CheckValue(mapper, nameof(mapper));
9997
Contracts.CheckValueOrNull(mapperFactory);
10098
_mapper = mapper;
10199
_mapperFactory = mapperFactory;
102100
_bindings = new ColumnBindings(input.Schema, mapper.GetOutputColumns());
103-
}
101+
}
104102

105-
public static Schema GetOutputSchema(Schema inputSchema, IRowMapper mapper)
106-
{
103+
[BestFriend]
104+
internal static Schema GetOutputSchema(Schema inputSchema, IRowMapper mapper)
105+
{
107106
Contracts.CheckValue(inputSchema, nameof(inputSchema));
108107
Contracts.CheckValue(mapper, nameof(mapper));
109108
return new ColumnBindings(inputSchema, mapper.GetOutputColumns()).Schema;
110-
}
109+
}
111110

112111
private RowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView input)
113112
: base(host, input)
114-
{
113+
{
115114
// *** Binary format ***
116115
// _mapper
117116

118117
ctx.LoadModel<IRowMapper, SignatureLoadRowMapper>(host, out _mapper, "Mapper", input.Schema);
119118
_bindings = new ColumnBindings(input.Schema, _mapper.GetOutputColumns());
120-
}
119+
}
121120

122121
public static RowToRowMapperTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
123-
{
122+
{
124123
Contracts.CheckValue(env, nameof(env));
125124
var h = env.Register(RegistrationName);
126125
h.CheckValue(ctx, nameof(ctx));
127126
ctx.CheckAtModel(GetVersionInfo());
128127
h.CheckValue(input, nameof(input));
129128
return h.Apply("Loading Model", ch => new RowToRowMapperTransform(h, ctx, input));
130-
}
129+
}
131130

132131
private protected override void SaveModel(ModelSaveContext ctx)
133-
{
132+
{
134133
Host.CheckValue(ctx, nameof(ctx));
135134
ctx.CheckAtModel();
136135
ctx.SetVersionInfo(GetVersionInfo());
@@ -139,14 +138,14 @@ private protected override void SaveModel(ModelSaveContext ctx)
139138
// _mapper
140139

141140
ctx.SaveModel(_mapper, "Mapper");
142-
}
141+
}
143142

144143
/// <summary>
145144
/// Produces the set of active columns for the data view (as a bool[] of length _bindings.Schema.Count),
146145
/// and the needed active input columns, given a predicate for the needed active output columns.
147146
/// </summary>
148147
private bool[] GetActive(Func<int, bool> predicate, out IEnumerable<Schema.Column> inputColumns)
149-
{
148+
{
150149
int n = _bindings.Schema.Count;
151150
var active = Utils.BuildArray(n, predicate);
152151
Contracts.Assert(active.Length == n);
@@ -161,13 +160,13 @@ private bool[] GetActive(Func<int, bool> predicate, out IEnumerable<Schema.Colum
161160
var predicateIn = _mapper.GetDependencies(predicateOut);
162161

163162
// Combine the two sets of input columns.
164-
inputColumns = _bindings.InputSchema.Where(col => activeInput[col.Index]|| predicateIn(col.Index));
163+
inputColumns = _bindings.InputSchema.Where(col => activeInput[col.Index] || predicateIn(col.Index));
165164

166165
return active;
167-
}
166+
}
168167

169168
private Func<int, bool> GetActiveOutputColumns(bool[] active)
170-
{
169+
{
171170
Contracts.AssertValue(active);
172171
Contracts.Assert(active.Length == _bindings.Schema.Count);
173172

@@ -177,26 +176,26 @@ private Func<int, bool> GetActiveOutputColumns(bool[] active)
177176
Contracts.Assert(0 <= col && col < _bindings.AddedColumnIndices.Count);
178177
return 0 <= col && col < _bindings.AddedColumnIndices.Count && active[_bindings.AddedColumnIndices[col]];
179178
};
180-
}
179+
}
181180

182181
protected override bool? ShouldUseParallelCursors(Func<int, bool> predicate)
183-
{
182+
{
184183
Host.AssertValue(predicate, "predicate");
185184
if (_bindings.AddedColumnIndices.Any(predicate))
186185
return true;
187186
return null;
188-
}
187+
}
189188

190189
protected override RowCursor GetRowCursorCore(IEnumerable<Schema.Column> columnsNeeded, Random rand = null)
191-
{
190+
{
192191
var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
193192
var active = GetActive(predicate, out IEnumerable<Schema.Column> inputCols);
194193

195194
return new Cursor(Host, Source.GetRowCursor(inputCols, rand), this, active);
196-
}
195+
}
197196

198197
public override RowCursor[] GetRowCursorSet(IEnumerable<Schema.Column> columnsNeeded, int n, Random rand = null)
199-
{
198+
{
200199
Host.CheckValueOrNull(rand);
201200

202201
var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
@@ -213,89 +212,89 @@ public override RowCursor[] GetRowCursorSet(IEnumerable<Schema.Column> columnsNe
213212
for (int i = 0; i < inputs.Length; i++)
214213
cursors[i] = new Cursor(Host, inputs[i], this, active);
215214
return cursors;
216-
}
215+
}
217216

218217
void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx)
219-
{
218+
{
220219
Host.CheckValue(ctx, nameof(ctx));
221220
if (_mapper is ISaveAsOnnx onnx)
222-
{
221+
{
223222
Host.Check(onnx.CanSaveOnnx(ctx), "Cannot be saved as ONNX.");
224223
onnx.SaveAsOnnx(ctx);
224+
}
225225
}
226-
}
227226

228227
void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx)
229-
{
228+
{
230229
Host.CheckValue(ctx, nameof(ctx));
231230
if (_mapper is ISaveAsPfa pfa)
232-
{
231+
{
233232
Host.Check(pfa.CanSavePfa, "Cannot be saved as PFA.");
234233
pfa.SaveAsPfa(ctx);
234+
}
235235
}
236-
}
237236

238237
/// <summary>
239238
/// Given a set of output columns, return the input columns that are needed to generate those output columns.
240239
/// </summary>
241240
IEnumerable<Schema.Column> IRowToRowMapper.GetDependencies(IEnumerable<Schema.Column> dependingColumns)
242-
{
241+
{
243242
var predicate = RowCursorUtils.FromColumnsToPredicate(dependingColumns, OutputSchema);
244-
GetActive(predicate, out IEnumerable<Schema.Column> inputColumns);
243+
GetActive(predicate, out var inputColumns);
245244
return inputColumns;
246-
}
245+
}
247246

248247
public Schema InputSchema => Source.Schema;
249248

250249
public Row GetRow(Row input, Func<int, bool> active)
251-
{
250+
{
252251
Host.CheckValue(input, nameof(input));
253252
Host.CheckValue(active, nameof(active));
254253
Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to");
255254

256255
using (var ch = Host.Start("GetEntireRow"))
257-
{
256+
{
258257
var activeArr = new bool[OutputSchema.Count];
259258
for (int i = 0; i < OutputSchema.Count; i++)
260259
activeArr[i] = active(i);
261260
var pred = GetActiveOutputColumns(activeArr);
262261
var getters = _mapper.CreateGetters(input, pred, out Action disp);
263262
return new RowImpl(input, this, OutputSchema, getters, disp);
263+
}
264264
}
265-
}
266265

267266
IDataTransform ITransformTemplate.ApplyToData(IHostEnvironment env, IDataView newSource)
268-
{
267+
{
269268
Contracts.CheckValue(env, nameof(env));
270269

271270
Contracts.CheckValue(newSource, nameof(newSource));
272271
if (_mapperFactory != null)
273-
{
272+
{
274273
var newMapper = _mapperFactory(newSource.Schema);
275274
return new RowToRowMapperTransform(env.Register(nameof(RowToRowMapperTransform)), newSource, newMapper, _mapperFactory);
276-
}
275+
}
277276
// Fall back to serialization. This used to be the path in all cases; now it is used only when we can't re-create the mapper.
278277
using (var stream = new MemoryStream())
279-
{
280-
using (var rep = RepositoryWriter.CreateNew(stream, env))
281278
{
279+
using (var rep = RepositoryWriter.CreateNew(stream, env))
280+
{
282281
ModelSaveContext.SaveModel(rep, this, "model");
283282
rep.Commit();
284-
}
283+
}
285284

286285
stream.Position = 0;
287286
using (var rep = RepositoryReader.Open(stream, env))
288-
{
287+
{
289288
IDataTransform newData;
290289
ModelLoadContext.LoadModel<IDataTransform, SignatureLoadDataTransform>(env,
291290
out newData, rep, "model", newSource);
292291
return newData;
292+
}
293293
}
294294
}
295-
}
296295

297296
private sealed class RowImpl : WrappingRow
298-
{
297+
{
299298
private readonly Delegate[] _getters;
300299
private readonly RowToRowMapperTransform _parent;
301300
private readonly Action _disposer;
@@ -304,21 +303,21 @@ private sealed class RowImpl : WrappingRow
304303

305304
public RowImpl(Row input, RowToRowMapperTransform parent, Schema schema, Delegate[] getters, Action disposer)
306305
: base(input)
307-
{
306+
{
308307
_parent = parent;
309308
Schema = schema;
310309
_getters = getters;
311310
_disposer = disposer;
312-
}
311+
}
313312

314313
protected override void DisposeCore(bool disposing)
315-
{
314+
{
316315
if (disposing)
317316
_disposer?.Invoke();
318-
}
317+
}
319318

320319
public override ValueGetter<TValue> GetGetter<TValue>(int col)
321-
{
320+
{
322321
bool isSrc;
323322
int index = _parent._bindings.MapColumnIndex(out isSrc, col);
324323
if (isSrc)
@@ -329,20 +328,20 @@ public override ValueGetter<TValue> GetGetter<TValue>(int col)
329328
if (fn == null)
330329
throw Contracts.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue));
331330
return fn;
332-
}
331+
}
333332

334333
public override bool IsColumnActive(int col)
335-
{
334+
{
336335
bool isSrc;
337336
int index = _parent._bindings.MapColumnIndex(out isSrc, col);
338337
if (isSrc)
339338
return Input.IsColumnActive((index));
340339
return _getters[index] != null;
340+
}
341341
}
342-
}
343342

344343
private sealed class Cursor : SynchronizedCursorBase
345-
{
344+
{
346345
private readonly Delegate[] _getters;
347346
private readonly bool[] _active;
348347
private readonly ColumnBindings _bindings;
@@ -353,21 +352,21 @@ private sealed class Cursor : SynchronizedCursorBase
353352

354353
public Cursor(IChannelProvider provider, RowCursor input, RowToRowMapperTransform parent, bool[] active)
355354
: base(provider, input)
356-
{
355+
{
357356
var pred = parent.GetActiveOutputColumns(active);
358357
_getters = parent._mapper.CreateGetters(input, pred, out _disposer);
359358
_active = active;
360359
_bindings = parent._bindings;
361-
}
360+
}
362361

363362
public override bool IsColumnActive(int col)
364-
{
363+
{
365364
Ch.Check(0 <= col && col < _bindings.Schema.Count);
366365
return _active[col];
367-
}
366+
}
368367

369368
public override ValueGetter<TValue> GetGetter<TValue>(int col)
370-
{
369+
{
371370
Ch.Check(IsColumnActive(col));
372371

373372
bool isSrc;
@@ -382,22 +381,22 @@ public override ValueGetter<TValue> GetGetter<TValue>(int col)
382381
if (fn == null)
383382
throw Ch.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue));
384383
return fn;
385-
}
384+
}
386385

387386
protected override void Dispose(bool disposing)
388-
{
387+
{
389388
if (_disposed)
390389
return;
391390
if (disposing)
392391
_disposer?.Invoke();
393392
_disposed = true;
394393
base.Dispose(disposing);
394+
}
395395
}
396-
}
397396

398397
internal ITransformer GetTransformer()
399-
{
398+
{
400399
return _mapper.GetTransformer();
400+
}
401401
}
402402
}
403-
}

0 commit comments

Comments
 (0)