@@ -27,7 +27,7 @@ namespace Microsoft.ML.Data
27
27
/// </summary>
28
28
[ BestFriend ]
29
29
internal interface IRowMapper : ICanSaveModel
30
- {
30
+ {
31
31
/// <summary>
32
32
/// Returns the input columns needed for the requested output columns.
33
33
/// </summary>
@@ -54,9 +54,7 @@ internal interface IRowMapper : ICanSaveModel
54
54
/// Returns parent transfomer which uses this mapper.
55
55
/// </summary>
56
56
ITransformer GetTransformer ( ) ;
57
- }
58
- [ BestFriend ]
59
- internal delegate void SignatureLoadRowMapper ( ModelLoadContext ctx , Schema schema ) ;
57
+ }
60
58
61
59
/// <summary>
62
60
/// This class is a transform that can add any number of output columns, that depend on any number of input columns.
@@ -66,7 +64,7 @@ internal interface IRowMapper : ICanSaveModel
66
64
[ BestFriend ]
67
65
internal sealed class RowToRowMapperTransform : RowToRowTransformBase , IRowToRowMapper ,
68
66
ITransformCanSaveOnnx , ITransformCanSavePfa , ITransformTemplate
69
- {
67
+ {
70
68
private readonly IRowMapper _mapper ;
71
69
private readonly ColumnBindings _bindings ;
72
70
@@ -76,15 +74,15 @@ internal sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRow
76
74
public const string RegistrationName = "RowToRowMapperTransform" ;
77
75
public const string LoaderSignature = "RowToRowMapper" ;
78
76
private static VersionInfo GetVersionInfo ( )
79
- {
77
+ {
80
78
return new VersionInfo (
81
79
modelSignature : "ROW MPPR" ,
82
80
verWrittenCur : 0x00010001 , // Initial
83
81
verReadableCur : 0x00010001 ,
84
82
verWeCanReadBack : 0x00010001 ,
85
83
loaderSignature : LoaderSignature ,
86
84
loaderAssemblyName : typeof ( RowToRowMapperTransform ) . Assembly . FullName ) ;
87
- }
85
+ }
88
86
89
87
public override Schema OutputSchema => _bindings . Schema ;
90
88
@@ -94,43 +92,44 @@ private static VersionInfo GetVersionInfo()
94
92
95
93
public RowToRowMapperTransform ( IHostEnvironment env , IDataView input , IRowMapper mapper , Func < Schema , IRowMapper > mapperFactory )
96
94
: base ( env , RegistrationName , input )
97
- {
95
+ {
98
96
Contracts . CheckValue ( mapper , nameof ( mapper ) ) ;
99
97
Contracts . CheckValueOrNull ( mapperFactory ) ;
100
98
_mapper = mapper ;
101
99
_mapperFactory = mapperFactory ;
102
100
_bindings = new ColumnBindings ( input . Schema , mapper . GetOutputColumns ( ) ) ;
103
- }
101
+ }
104
102
105
- public static Schema GetOutputSchema ( Schema inputSchema , IRowMapper mapper )
106
- {
103
+ [ BestFriend ]
104
+ internal static Schema GetOutputSchema ( Schema inputSchema , IRowMapper mapper )
105
+ {
107
106
Contracts . CheckValue ( inputSchema , nameof ( inputSchema ) ) ;
108
107
Contracts . CheckValue ( mapper , nameof ( mapper ) ) ;
109
108
return new ColumnBindings ( inputSchema , mapper . GetOutputColumns ( ) ) . Schema ;
110
- }
109
+ }
111
110
112
111
private RowToRowMapperTransform ( IHost host , ModelLoadContext ctx , IDataView input )
113
112
: base ( host , input )
114
- {
113
+ {
115
114
// *** Binary format ***
116
115
// _mapper
117
116
118
117
ctx . LoadModel < IRowMapper , SignatureLoadRowMapper > ( host , out _mapper , "Mapper" , input . Schema ) ;
119
118
_bindings = new ColumnBindings ( input . Schema , _mapper . GetOutputColumns ( ) ) ;
120
- }
119
+ }
121
120
122
121
public static RowToRowMapperTransform Create ( IHostEnvironment env , ModelLoadContext ctx , IDataView input )
123
- {
122
+ {
124
123
Contracts . CheckValue ( env , nameof ( env ) ) ;
125
124
var h = env . Register ( RegistrationName ) ;
126
125
h . CheckValue ( ctx , nameof ( ctx ) ) ;
127
126
ctx . CheckAtModel ( GetVersionInfo ( ) ) ;
128
127
h . CheckValue ( input , nameof ( input ) ) ;
129
128
return h . Apply ( "Loading Model" , ch => new RowToRowMapperTransform ( h , ctx , input ) ) ;
130
- }
129
+ }
131
130
132
131
private protected override void SaveModel ( ModelSaveContext ctx )
133
- {
132
+ {
134
133
Host . CheckValue ( ctx , nameof ( ctx ) ) ;
135
134
ctx . CheckAtModel ( ) ;
136
135
ctx . SetVersionInfo ( GetVersionInfo ( ) ) ;
@@ -139,14 +138,14 @@ private protected override void SaveModel(ModelSaveContext ctx)
139
138
// _mapper
140
139
141
140
ctx . SaveModel ( _mapper , "Mapper" ) ;
142
- }
141
+ }
143
142
144
143
/// <summary>
145
144
/// Produces the set of active columns for the data view (as a bool[] of length bindings.ColumnCount),
146
145
/// and the needed active input columns, given a predicate for the needed active output columns.
147
146
/// </summary>
148
147
private bool [ ] GetActive ( Func < int , bool > predicate , out IEnumerable < Schema . Column > inputColumns )
149
- {
148
+ {
150
149
int n = _bindings . Schema . Count ;
151
150
var active = Utils . BuildArray ( n , predicate ) ;
152
151
Contracts . Assert ( active . Length == n ) ;
@@ -161,13 +160,13 @@ private bool[] GetActive(Func<int, bool> predicate, out IEnumerable<Schema.Colum
161
160
var predicateIn = _mapper . GetDependencies ( predicateOut ) ;
162
161
163
162
// Combine the two sets of input columns.
164
- inputColumns = _bindings . InputSchema . Where ( col => activeInput [ col . Index ] || predicateIn ( col . Index ) ) ;
163
+ inputColumns = _bindings . InputSchema . Where ( col => activeInput [ col . Index ] || predicateIn ( col . Index ) ) ;
165
164
166
165
return active ;
167
- }
166
+ }
168
167
169
168
private Func < int , bool > GetActiveOutputColumns ( bool [ ] active )
170
- {
169
+ {
171
170
Contracts . AssertValue ( active ) ;
172
171
Contracts . Assert ( active . Length == _bindings . Schema . Count ) ;
173
172
@@ -177,26 +176,26 @@ private Func<int, bool> GetActiveOutputColumns(bool[] active)
177
176
Contracts . Assert ( 0 <= col && col < _bindings . AddedColumnIndices . Count ) ;
178
177
return 0 <= col && col < _bindings . AddedColumnIndices . Count && active [ _bindings . AddedColumnIndices [ col ] ] ;
179
178
} ;
180
- }
179
+ }
181
180
182
181
protected override bool ? ShouldUseParallelCursors ( Func < int , bool > predicate )
183
- {
182
+ {
184
183
Host . AssertValue ( predicate , "predicate" ) ;
185
184
if ( _bindings . AddedColumnIndices . Any ( predicate ) )
186
185
return true ;
187
186
return null ;
188
- }
187
+ }
189
188
190
189
protected override RowCursor GetRowCursorCore ( IEnumerable < Schema . Column > columnsNeeded , Random rand = null )
191
- {
190
+ {
192
191
var predicate = RowCursorUtils . FromColumnsToPredicate ( columnsNeeded , OutputSchema ) ;
193
192
var active = GetActive ( predicate , out IEnumerable < Schema . Column > inputCols ) ;
194
193
195
194
return new Cursor ( Host , Source . GetRowCursor ( inputCols , rand ) , this , active ) ;
196
- }
195
+ }
197
196
198
197
public override RowCursor [ ] GetRowCursorSet ( IEnumerable < Schema . Column > columnsNeeded , int n , Random rand = null )
199
- {
198
+ {
200
199
Host . CheckValueOrNull ( rand ) ;
201
200
202
201
var predicate = RowCursorUtils . FromColumnsToPredicate ( columnsNeeded , OutputSchema ) ;
@@ -213,89 +212,89 @@ public override RowCursor[] GetRowCursorSet(IEnumerable<Schema.Column> columnsNe
213
212
for ( int i = 0 ; i < inputs . Length ; i ++ )
214
213
cursors [ i ] = new Cursor ( Host , inputs [ i ] , this , active ) ;
215
214
return cursors ;
216
- }
215
+ }
217
216
218
217
void ISaveAsOnnx . SaveAsOnnx ( OnnxContext ctx )
219
- {
218
+ {
220
219
Host . CheckValue ( ctx , nameof ( ctx ) ) ;
221
220
if ( _mapper is ISaveAsOnnx onnx )
222
- {
221
+ {
223
222
Host . Check ( onnx . CanSaveOnnx ( ctx ) , "Cannot be saved as ONNX." ) ;
224
223
onnx . SaveAsOnnx ( ctx ) ;
224
+ }
225
225
}
226
- }
227
226
228
227
void ISaveAsPfa . SaveAsPfa ( BoundPfaContext ctx )
229
- {
228
+ {
230
229
Host . CheckValue ( ctx , nameof ( ctx ) ) ;
231
230
if ( _mapper is ISaveAsPfa pfa )
232
- {
231
+ {
233
232
Host . Check ( pfa . CanSavePfa , "Cannot be saved as PFA." ) ;
234
233
pfa . SaveAsPfa ( ctx ) ;
234
+ }
235
235
}
236
- }
237
236
238
237
/// <summary>
239
238
/// Given a set of output columns, return the input columns that are needed to generate those output columns.
240
239
/// </summary>
241
240
IEnumerable < Schema . Column > IRowToRowMapper . GetDependencies ( IEnumerable < Schema . Column > dependingColumns )
242
- {
241
+ {
243
242
var predicate = RowCursorUtils . FromColumnsToPredicate ( dependingColumns , OutputSchema ) ;
244
- GetActive ( predicate , out IEnumerable < Schema . Column > inputColumns ) ;
243
+ GetActive ( predicate , out var inputColumns ) ;
245
244
return inputColumns ;
246
- }
245
+ }
247
246
248
247
public Schema InputSchema => Source . Schema ;
249
248
250
249
public Row GetRow ( Row input , Func < int , bool > active )
251
- {
250
+ {
252
251
Host . CheckValue ( input , nameof ( input ) ) ;
253
252
Host . CheckValue ( active , nameof ( active ) ) ;
254
253
Host . Check ( input . Schema == Source . Schema , "Schema of input row must be the same as the schema the mapper is bound to" ) ;
255
254
256
255
using ( var ch = Host . Start ( "GetEntireRow" ) )
257
- {
256
+ {
258
257
var activeArr = new bool [ OutputSchema . Count ] ;
259
258
for ( int i = 0 ; i < OutputSchema . Count ; i ++ )
260
259
activeArr [ i ] = active ( i ) ;
261
260
var pred = GetActiveOutputColumns ( activeArr ) ;
262
261
var getters = _mapper . CreateGetters ( input , pred , out Action disp ) ;
263
262
return new RowImpl ( input , this , OutputSchema , getters , disp ) ;
263
+ }
264
264
}
265
- }
266
265
267
266
IDataTransform ITransformTemplate . ApplyToData ( IHostEnvironment env , IDataView newSource )
268
- {
267
+ {
269
268
Contracts . CheckValue ( env , nameof ( env ) ) ;
270
269
271
270
Contracts . CheckValue ( newSource , nameof ( newSource ) ) ;
272
271
if ( _mapperFactory != null )
273
- {
272
+ {
274
273
var newMapper = _mapperFactory ( newSource . Schema ) ;
275
274
return new RowToRowMapperTransform ( env . Register ( nameof ( RowToRowMapperTransform ) ) , newSource , newMapper , _mapperFactory ) ;
276
- }
275
+ }
277
276
// Revert to serialization. This was how it worked in all the cases, now it's only when we can't re-create the mapper.
278
277
using ( var stream = new MemoryStream ( ) )
279
- {
280
- using ( var rep = RepositoryWriter . CreateNew ( stream , env ) )
281
278
{
279
+ using ( var rep = RepositoryWriter . CreateNew ( stream , env ) )
280
+ {
282
281
ModelSaveContext . SaveModel ( rep , this , "model" ) ;
283
282
rep . Commit ( ) ;
284
- }
283
+ }
285
284
286
285
stream . Position = 0 ;
287
286
using ( var rep = RepositoryReader . Open ( stream , env ) )
288
- {
287
+ {
289
288
IDataTransform newData ;
290
289
ModelLoadContext . LoadModel < IDataTransform , SignatureLoadDataTransform > ( env ,
291
290
out newData , rep , "model" , newSource ) ;
292
291
return newData ;
292
+ }
293
293
}
294
294
}
295
- }
296
295
297
296
private sealed class RowImpl : WrappingRow
298
- {
297
+ {
299
298
private readonly Delegate [ ] _getters ;
300
299
private readonly RowToRowMapperTransform _parent ;
301
300
private readonly Action _disposer ;
@@ -304,21 +303,21 @@ private sealed class RowImpl : WrappingRow
304
303
305
304
public RowImpl ( Row input , RowToRowMapperTransform parent , Schema schema , Delegate [ ] getters , Action disposer )
306
305
: base ( input )
307
- {
306
+ {
308
307
_parent = parent ;
309
308
Schema = schema ;
310
309
_getters = getters ;
311
310
_disposer = disposer ;
312
- }
311
+ }
313
312
314
313
protected override void DisposeCore ( bool disposing )
315
- {
314
+ {
316
315
if ( disposing )
317
316
_disposer ? . Invoke ( ) ;
318
- }
317
+ }
319
318
320
319
public override ValueGetter < TValue > GetGetter < TValue > ( int col )
321
- {
320
+ {
322
321
bool isSrc ;
323
322
int index = _parent . _bindings . MapColumnIndex ( out isSrc , col ) ;
324
323
if ( isSrc )
@@ -329,20 +328,20 @@ public override ValueGetter<TValue> GetGetter<TValue>(int col)
329
328
if ( fn == null )
330
329
throw Contracts . Except ( "Invalid TValue in GetGetter: '{0}'" , typeof ( TValue ) ) ;
331
330
return fn ;
332
- }
331
+ }
333
332
334
333
public override bool IsColumnActive ( int col )
335
- {
334
+ {
336
335
bool isSrc ;
337
336
int index = _parent . _bindings . MapColumnIndex ( out isSrc , col ) ;
338
337
if ( isSrc )
339
338
return Input . IsColumnActive ( ( index ) ) ;
340
339
return _getters [ index ] != null ;
340
+ }
341
341
}
342
- }
343
342
344
343
private sealed class Cursor : SynchronizedCursorBase
345
- {
344
+ {
346
345
private readonly Delegate [ ] _getters ;
347
346
private readonly bool [ ] _active ;
348
347
private readonly ColumnBindings _bindings ;
@@ -353,21 +352,21 @@ private sealed class Cursor : SynchronizedCursorBase
353
352
354
353
public Cursor ( IChannelProvider provider , RowCursor input , RowToRowMapperTransform parent , bool [ ] active )
355
354
: base ( provider , input )
356
- {
355
+ {
357
356
var pred = parent . GetActiveOutputColumns ( active ) ;
358
357
_getters = parent . _mapper . CreateGetters ( input , pred , out _disposer ) ;
359
358
_active = active ;
360
359
_bindings = parent . _bindings ;
361
- }
360
+ }
362
361
363
362
public override bool IsColumnActive ( int col )
364
- {
363
+ {
365
364
Ch . Check ( 0 <= col && col < _bindings . Schema . Count ) ;
366
365
return _active [ col ] ;
367
- }
366
+ }
368
367
369
368
public override ValueGetter < TValue > GetGetter < TValue > ( int col )
370
- {
369
+ {
371
370
Ch . Check ( IsColumnActive ( col ) ) ;
372
371
373
372
bool isSrc ;
@@ -382,22 +381,22 @@ public override ValueGetter<TValue> GetGetter<TValue>(int col)
382
381
if ( fn == null )
383
382
throw Ch . Except ( "Invalid TValue in GetGetter: '{0}'" , typeof ( TValue ) ) ;
384
383
return fn ;
385
- }
384
+ }
386
385
387
386
protected override void Dispose ( bool disposing )
388
- {
387
+ {
389
388
if ( _disposed )
390
389
return ;
391
390
if ( disposing )
392
391
_disposer ? . Invoke ( ) ;
393
392
_disposed = true ;
394
393
base . Dispose ( disposing ) ;
394
+ }
395
395
}
396
- }
397
396
398
397
internal ITransformer GetTransformer ( )
399
- {
398
+ {
400
399
return _mapper . GetTransformer ( ) ;
400
+ }
401
401
}
402
402
}
403
- }
0 commit comments