@@ -27,7 +27,7 @@ namespace Microsoft.ML.Data
27
27
/// </summary>
28
28
[ BestFriend ]
29
29
internal interface IRowMapper : ICanSaveModel
30
- {
30
+ {
31
31
/// <summary>
32
32
/// Returns the input columns needed for the requested output columns.
33
33
/// </summary>
@@ -54,7 +54,10 @@ internal interface IRowMapper : ICanSaveModel
54
54
/// Returns parent transfomer which uses this mapper.
55
55
/// </summary>
56
56
ITransformer GetTransformer ( ) ;
57
- }
57
+ }
58
+
59
+ [ BestFriend ]
60
+ internal delegate void SignatureLoadRowMapper ( ModelLoadContext ctx , Schema schema ) ;
58
61
59
62
/// <summary>
60
63
/// This class is a transform that can add any number of output columns, that depend on any number of input columns.
@@ -64,7 +67,7 @@ internal interface IRowMapper : ICanSaveModel
64
67
[ BestFriend ]
65
68
internal sealed class RowToRowMapperTransform : RowToRowTransformBase , IRowToRowMapper ,
66
69
ITransformCanSaveOnnx , ITransformCanSavePfa , ITransformTemplate
67
- {
70
+ {
68
71
private readonly IRowMapper _mapper ;
69
72
private readonly ColumnBindings _bindings ;
70
73
@@ -74,15 +77,15 @@ internal sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRow
74
77
public const string RegistrationName = "RowToRowMapperTransform" ;
75
78
public const string LoaderSignature = "RowToRowMapper" ;
76
79
private static VersionInfo GetVersionInfo ( )
77
- {
80
+ {
78
81
return new VersionInfo (
79
82
modelSignature : "ROW MPPR" ,
80
83
verWrittenCur : 0x00010001 , // Initial
81
84
verReadableCur : 0x00010001 ,
82
85
verWeCanReadBack : 0x00010001 ,
83
86
loaderSignature : LoaderSignature ,
84
87
loaderAssemblyName : typeof ( RowToRowMapperTransform ) . Assembly . FullName ) ;
85
- }
88
+ }
86
89
87
90
public override Schema OutputSchema => _bindings . Schema ;
88
91
@@ -92,44 +95,44 @@ private static VersionInfo GetVersionInfo()
92
95
93
96
public RowToRowMapperTransform ( IHostEnvironment env , IDataView input , IRowMapper mapper , Func < Schema , IRowMapper > mapperFactory )
94
97
: base ( env , RegistrationName , input )
95
- {
98
+ {
96
99
Contracts . CheckValue ( mapper , nameof ( mapper ) ) ;
97
100
Contracts . CheckValueOrNull ( mapperFactory ) ;
98
101
_mapper = mapper ;
99
102
_mapperFactory = mapperFactory ;
100
103
_bindings = new ColumnBindings ( input . Schema , mapper . GetOutputColumns ( ) ) ;
101
- }
104
+ }
102
105
103
106
[ BestFriend ]
104
107
internal static Schema GetOutputSchema ( Schema inputSchema , IRowMapper mapper )
105
- {
108
+ {
106
109
Contracts . CheckValue ( inputSchema , nameof ( inputSchema ) ) ;
107
110
Contracts . CheckValue ( mapper , nameof ( mapper ) ) ;
108
111
return new ColumnBindings ( inputSchema , mapper . GetOutputColumns ( ) ) . Schema ;
109
- }
112
+ }
110
113
111
114
private RowToRowMapperTransform ( IHost host , ModelLoadContext ctx , IDataView input )
112
115
: base ( host , input )
113
- {
116
+ {
114
117
// *** Binary format ***
115
118
// _mapper
116
119
117
120
ctx . LoadModel < IRowMapper , SignatureLoadRowMapper > ( host , out _mapper , "Mapper" , input . Schema ) ;
118
121
_bindings = new ColumnBindings ( input . Schema , _mapper . GetOutputColumns ( ) ) ;
119
- }
122
+ }
120
123
121
124
public static RowToRowMapperTransform Create ( IHostEnvironment env , ModelLoadContext ctx , IDataView input )
122
- {
125
+ {
123
126
Contracts . CheckValue ( env , nameof ( env ) ) ;
124
127
var h = env . Register ( RegistrationName ) ;
125
128
h . CheckValue ( ctx , nameof ( ctx ) ) ;
126
129
ctx . CheckAtModel ( GetVersionInfo ( ) ) ;
127
130
h . CheckValue ( input , nameof ( input ) ) ;
128
131
return h . Apply ( "Loading Model" , ch => new RowToRowMapperTransform ( h , ctx , input ) ) ;
129
- }
132
+ }
130
133
131
134
private protected override void SaveModel ( ModelSaveContext ctx )
132
- {
135
+ {
133
136
Host . CheckValue ( ctx , nameof ( ctx ) ) ;
134
137
ctx . CheckAtModel ( ) ;
135
138
ctx . SetVersionInfo ( GetVersionInfo ( ) ) ;
@@ -138,14 +141,14 @@ private protected override void SaveModel(ModelSaveContext ctx)
138
141
// _mapper
139
142
140
143
ctx . SaveModel ( _mapper , "Mapper" ) ;
141
- }
144
+ }
142
145
143
146
/// <summary>
144
147
/// Produces the set of active columns for the data view (as a bool[] of length bindings.ColumnCount),
145
148
/// and the needed active input columns, given a predicate for the needed active output columns.
146
149
/// </summary>
147
150
private bool [ ] GetActive ( Func < int , bool > predicate , out IEnumerable < Schema . Column > inputColumns )
148
- {
151
+ {
149
152
int n = _bindings . Schema . Count ;
150
153
var active = Utils . BuildArray ( n , predicate ) ;
151
154
Contracts . Assert ( active . Length == n ) ;
@@ -163,10 +166,10 @@ private bool[] GetActive(Func<int, bool> predicate, out IEnumerable<Schema.Colum
163
166
inputColumns = _bindings . InputSchema . Where ( col => activeInput [ col . Index ] || predicateIn ( col . Index ) ) ;
164
167
165
168
return active ;
166
- }
169
+ }
167
170
168
171
private Func < int , bool > GetActiveOutputColumns ( bool [ ] active )
169
- {
172
+ {
170
173
Contracts . AssertValue ( active ) ;
171
174
Contracts . Assert ( active . Length == _bindings . Schema . Count ) ;
172
175
@@ -176,26 +179,26 @@ private Func<int, bool> GetActiveOutputColumns(bool[] active)
176
179
Contracts . Assert ( 0 <= col && col < _bindings . AddedColumnIndices . Count ) ;
177
180
return 0 <= col && col < _bindings . AddedColumnIndices . Count && active [ _bindings . AddedColumnIndices [ col ] ] ;
178
181
} ;
179
- }
182
+ }
180
183
181
184
protected override bool ? ShouldUseParallelCursors ( Func < int , bool > predicate )
182
- {
185
+ {
183
186
Host . AssertValue ( predicate , "predicate" ) ;
184
187
if ( _bindings . AddedColumnIndices . Any ( predicate ) )
185
188
return true ;
186
189
return null ;
187
- }
190
+ }
188
191
189
192
protected override RowCursor GetRowCursorCore ( IEnumerable < Schema . Column > columnsNeeded , Random rand = null )
190
- {
193
+ {
191
194
var predicate = RowCursorUtils . FromColumnsToPredicate ( columnsNeeded , OutputSchema ) ;
192
195
var active = GetActive ( predicate , out IEnumerable < Schema . Column > inputCols ) ;
193
196
194
197
return new Cursor ( Host , Source . GetRowCursor ( inputCols , rand ) , this , active ) ;
195
- }
198
+ }
196
199
197
200
public override RowCursor [ ] GetRowCursorSet ( IEnumerable < Schema . Column > columnsNeeded , int n , Random rand = null )
198
- {
201
+ {
199
202
Host . CheckValueOrNull ( rand ) ;
200
203
201
204
var predicate = RowCursorUtils . FromColumnsToPredicate ( columnsNeeded , OutputSchema ) ;
@@ -212,89 +215,89 @@ public override RowCursor[] GetRowCursorSet(IEnumerable<Schema.Column> columnsNe
212
215
for ( int i = 0 ; i < inputs . Length ; i ++ )
213
216
cursors [ i ] = new Cursor ( Host , inputs [ i ] , this , active ) ;
214
217
return cursors ;
215
- }
218
+ }
216
219
217
220
void ISaveAsOnnx . SaveAsOnnx ( OnnxContext ctx )
218
- {
221
+ {
219
222
Host . CheckValue ( ctx , nameof ( ctx ) ) ;
220
223
if ( _mapper is ISaveAsOnnx onnx )
221
- {
224
+ {
222
225
Host . Check ( onnx . CanSaveOnnx ( ctx ) , "Cannot be saved as ONNX." ) ;
223
226
onnx . SaveAsOnnx ( ctx ) ;
224
- }
225
227
}
228
+ }
226
229
227
230
void ISaveAsPfa . SaveAsPfa ( BoundPfaContext ctx )
228
- {
231
+ {
229
232
Host . CheckValue ( ctx , nameof ( ctx ) ) ;
230
233
if ( _mapper is ISaveAsPfa pfa )
231
- {
234
+ {
232
235
Host . Check ( pfa . CanSavePfa , "Cannot be saved as PFA." ) ;
233
236
pfa . SaveAsPfa ( ctx ) ;
234
- }
235
237
}
238
+ }
236
239
237
240
/// <summary>
238
241
/// Given a set of output columns, return the input columns that are needed to generate those output columns.
239
242
/// </summary>
240
243
IEnumerable < Schema . Column > IRowToRowMapper . GetDependencies ( IEnumerable < Schema . Column > dependingColumns )
241
- {
244
+ {
242
245
var predicate = RowCursorUtils . FromColumnsToPredicate ( dependingColumns , OutputSchema ) ;
243
246
GetActive ( predicate , out var inputColumns ) ;
244
247
return inputColumns ;
245
- }
248
+ }
246
249
247
250
public Schema InputSchema => Source . Schema ;
248
251
249
252
public Row GetRow ( Row input , Func < int , bool > active )
250
- {
253
+ {
251
254
Host . CheckValue ( input , nameof ( input ) ) ;
252
255
Host . CheckValue ( active , nameof ( active ) ) ;
253
256
Host . Check ( input . Schema == Source . Schema , "Schema of input row must be the same as the schema the mapper is bound to" ) ;
254
257
255
258
using ( var ch = Host . Start ( "GetEntireRow" ) )
256
- {
259
+ {
257
260
var activeArr = new bool [ OutputSchema . Count ] ;
258
261
for ( int i = 0 ; i < OutputSchema . Count ; i ++ )
259
262
activeArr [ i ] = active ( i ) ;
260
263
var pred = GetActiveOutputColumns ( activeArr ) ;
261
264
var getters = _mapper . CreateGetters ( input , pred , out Action disp ) ;
262
265
return new RowImpl ( input , this , OutputSchema , getters , disp ) ;
263
- }
264
266
}
267
+ }
265
268
266
269
IDataTransform ITransformTemplate . ApplyToData ( IHostEnvironment env , IDataView newSource )
267
- {
270
+ {
268
271
Contracts . CheckValue ( env , nameof ( env ) ) ;
269
272
270
273
Contracts . CheckValue ( newSource , nameof ( newSource ) ) ;
271
274
if ( _mapperFactory != null )
272
- {
275
+ {
273
276
var newMapper = _mapperFactory ( newSource . Schema ) ;
274
277
return new RowToRowMapperTransform ( env . Register ( nameof ( RowToRowMapperTransform ) ) , newSource , newMapper , _mapperFactory ) ;
275
- }
278
+ }
276
279
// Revert to serialization. This was how it worked in all the cases, now it's only when we can't re-create the mapper.
277
280
using ( var stream = new MemoryStream ( ) )
278
- {
281
+ {
279
282
using ( var rep = RepositoryWriter . CreateNew ( stream , env ) )
280
- {
283
+ {
281
284
ModelSaveContext . SaveModel ( rep , this , "model" ) ;
282
285
rep . Commit ( ) ;
283
- }
286
+ }
284
287
285
288
stream . Position = 0 ;
286
289
using ( var rep = RepositoryReader . Open ( stream , env ) )
287
- {
290
+ {
288
291
IDataTransform newData ;
289
292
ModelLoadContext . LoadModel < IDataTransform , SignatureLoadDataTransform > ( env ,
290
293
out newData , rep , "model" , newSource ) ;
291
294
return newData ;
292
- }
293
295
}
294
296
}
297
+ }
295
298
296
299
private sealed class RowImpl : WrappingRow
297
- {
300
+ {
298
301
private readonly Delegate [ ] _getters ;
299
302
private readonly RowToRowMapperTransform _parent ;
300
303
private readonly Action _disposer ;
@@ -303,21 +306,21 @@ private sealed class RowImpl : WrappingRow
303
306
304
307
public RowImpl ( Row input , RowToRowMapperTransform parent , Schema schema , Delegate [ ] getters , Action disposer )
305
308
: base ( input )
306
- {
309
+ {
307
310
_parent = parent ;
308
311
Schema = schema ;
309
312
_getters = getters ;
310
313
_disposer = disposer ;
311
- }
314
+ }
312
315
313
316
protected override void DisposeCore ( bool disposing )
314
- {
317
+ {
315
318
if ( disposing )
316
319
_disposer ? . Invoke ( ) ;
317
- }
320
+ }
318
321
319
322
public override ValueGetter < TValue > GetGetter < TValue > ( int col )
320
- {
323
+ {
321
324
bool isSrc ;
322
325
int index = _parent . _bindings . MapColumnIndex ( out isSrc , col ) ;
323
326
if ( isSrc )
@@ -328,20 +331,20 @@ public override ValueGetter<TValue> GetGetter<TValue>(int col)
328
331
if ( fn == null )
329
332
throw Contracts . Except ( "Invalid TValue in GetGetter: '{0}'" , typeof ( TValue ) ) ;
330
333
return fn ;
331
- }
334
+ }
332
335
333
336
public override bool IsColumnActive ( int col )
334
- {
337
+ {
335
338
bool isSrc ;
336
339
int index = _parent . _bindings . MapColumnIndex ( out isSrc , col ) ;
337
340
if ( isSrc )
338
341
return Input . IsColumnActive ( ( index ) ) ;
339
342
return _getters [ index ] != null ;
340
- }
341
343
}
344
+ }
342
345
343
346
private sealed class Cursor : SynchronizedCursorBase
344
- {
347
+ {
345
348
private readonly Delegate [ ] _getters ;
346
349
private readonly bool [ ] _active ;
347
350
private readonly ColumnBindings _bindings ;
@@ -352,21 +355,21 @@ private sealed class Cursor : SynchronizedCursorBase
352
355
353
356
public Cursor ( IChannelProvider provider , RowCursor input , RowToRowMapperTransform parent , bool [ ] active )
354
357
: base ( provider , input )
355
- {
358
+ {
356
359
var pred = parent . GetActiveOutputColumns ( active ) ;
357
360
_getters = parent . _mapper . CreateGetters ( input , pred , out _disposer ) ;
358
361
_active = active ;
359
362
_bindings = parent . _bindings ;
360
- }
363
+ }
361
364
362
365
public override bool IsColumnActive ( int col )
363
- {
366
+ {
364
367
Ch . Check ( 0 <= col && col < _bindings . Schema . Count ) ;
365
368
return _active [ col ] ;
366
- }
369
+ }
367
370
368
371
public override ValueGetter < TValue > GetGetter < TValue > ( int col )
369
- {
372
+ {
370
373
Ch . Check ( IsColumnActive ( col ) ) ;
371
374
372
375
bool isSrc ;
@@ -381,22 +384,22 @@ public override ValueGetter<TValue> GetGetter<TValue>(int col)
381
384
if ( fn == null )
382
385
throw Ch . Except ( "Invalid TValue in GetGetter: '{0}'" , typeof ( TValue ) ) ;
383
386
return fn ;
384
- }
387
+ }
385
388
386
389
protected override void Dispose ( bool disposing )
387
- {
390
+ {
388
391
if ( _disposed )
389
392
return ;
390
393
if ( disposing )
391
394
_disposer ? . Invoke ( ) ;
392
395
_disposed = true ;
393
396
base . Dispose ( disposing ) ;
394
- }
395
397
}
398
+ }
396
399
397
400
internal ITransformer GetTransformer ( )
398
- {
401
+ {
399
402
return _mapper . GetTransformer ( ) ;
400
- }
401
403
}
402
404
}
405
+ }
0 commit comments