Skip to content

Commit 683711a

Browse files
authored
Replace predicate with an IEnumerable<DataViewSchema.Column> for IRowToRowMapper.GetRow and ISchemaBoundRowMapper.GetRow (#2796)
* Replace the `Func<int, bool> active` predicate with an `IEnumerable<DataViewSchema.Column> activeColumns` parameter in IRowToRowMapper.GetRow and ISchemaBoundRowMapper.GetRow. Rename `col` to `columnIndex` in `DataViewRow.IsColumnActive` and `DataViewRow.GetGetter<TValue>`.
* Rename `col` to `columnIndex` in all the overrides of IDataView.GetGetter and DataView.IsColumnActive, and update the documentation.
* Make sure GetRow is an explicit implementation everywhere.
* Address PR review: change the GetGetter method signature to take a column rather than a column index.
* Change the signature of IsColumnActive to take a DataViewSchema.Column instead of the column index.
1 parent f09b25f commit 683711a

File tree

144 files changed

+1749
-1235
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

144 files changed

+1749
-1235
lines changed

src/Microsoft.Data.DataView/DataViewSchema.cs

+6-6
Original file line numberDiff line numberDiff line change
@@ -217,14 +217,14 @@ private void CheckGetter<TValue>(Delegate getter)
217217
/// <summary>
218218
/// Get a getter delegate for one value of the annotations row.
219219
/// </summary>
220-
public ValueGetter<TValue> GetGetter<TValue>(int col)
220+
public ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
221221
{
222-
if (!(0 <= col && col < Schema.Count))
223-
throw new ArgumentOutOfRangeException(nameof(col));
224-
var typedGetter = _getters[col] as ValueGetter<TValue>;
222+
if (column.Index >= _getters.Length)
223+
throw new ArgumentException(nameof(column));
224+
var typedGetter = _getters[column.Index] as ValueGetter<TValue>;
225225
if (typedGetter == null)
226226
{
227-
Debug.Assert(_getters[col] != null);
227+
Debug.Assert(_getters[column.Index] != null);
228228
throw new InvalidOperationException($"Invalid call to '{nameof(GetGetter)}'");
229229
}
230230
return typedGetter;
@@ -238,7 +238,7 @@ public void GetValue<TValue>(string kind, ref TValue value)
238238
var column = Schema.GetColumnOrNull(kind);
239239
if (column == null)
240240
throw new InvalidOperationException($"Invalid call to '{nameof(GetValue)}'");
241-
GetGetter<TValue>(column.Value.Index)(ref value);
241+
GetGetter<TValue>(column.Value)(ref value);
242242
}
243243

244244
public override string ToString() => string.Join(", ", Schema.Select(x => x.Name));

src/Microsoft.Data.DataView/IDataView.cs

+5-3
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,16 @@ public abstract class DataViewRow : IDisposable
134134
/// <summary>
135135
/// Returns whether the given column is active in this row.
136136
/// </summary>
137-
public abstract bool IsColumnActive(int col);
137+
public abstract bool IsColumnActive(DataViewSchema.Column column);
138138

139139
/// <summary>
140-
/// Returns a value getter delegate to fetch the given column value from the row.
140+
/// Returns a value getter delegate to fetch the value of the given <paramref name="column"/>, from the row.
141141
/// This throws if the column is not active in this row, or if the type
142142
/// <typeparamref name="TValue"/> differs from this column's type.
143143
/// </summary>
144-
public abstract ValueGetter<TValue> GetGetter<TValue>(int col);
144+
/// <typeparam name="TValue"> is the column's content type.</typeparam>
145+
/// <param name="column"> is the output column whose getter should be returned.</param>
146+
public abstract ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column);
145147

146148
/// <summary>
147149
/// Gets a <see cref="Schema"/>, which provides name and type information for variables

src/Microsoft.Data.DataView/SchemaDebuggerProxy.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,16 @@ private static List<KeyValuePair<string, object>> BuildValues(DataViewSchema.Ann
4141
foreach (var column in annotations.Schema)
4242
{
4343
var name = column.Name;
44-
var value = Utils.MarshalInvoke(GetValue<int>, column.Type.RawType, annotations, column.Index);
44+
var value = Utils.MarshalInvoke(GetValue<DataViewSchema.Column>, column.Type.RawType, annotations, column);
4545
result.Add(new KeyValuePair<string, object>(name, value));
4646
}
4747
return result;
4848
}
4949

50-
private static object GetValue<T>(DataViewSchema.Annotations annotations, int columnIndex)
50+
private static object GetValue<T>(DataViewSchema.Annotations annotations, DataViewSchema.Column column)
5151
{
5252
T value = default;
53-
annotations.GetGetter<T>(columnIndex)(ref value);
53+
annotations.GetGetter<T>(column)(ref value);
5454
return value;
5555
}
5656
}

src/Microsoft.ML.Core/Data/AnnotationUtils.cs

+15-2
Original file line numberDiff line numberDiff line change
@@ -460,9 +460,22 @@ public AnnotationRow(DataViewSchema.Annotations annotations)
460460
public override DataViewSchema Schema => _annotations.Schema;
461461
public override long Position => 0;
462462
public override long Batch => 0;
463-
public override ValueGetter<TValue> GetGetter<TValue>(int col) => _annotations.GetGetter<TValue>(col);
463+
464+
/// <summary>
465+
/// Returns a value getter delegate to fetch the value of the given <paramref name="column"/> from the row.
466+
/// This throws if the column is not active in this row, or if the type
467+
/// <typeparamref name="TValue"/> differs from this column's type.
468+
/// </summary>
469+
/// <typeparam name="TValue"> is the column's content type.</typeparam>
470+
/// <param name="column"> is the output column whose getter should be returned.</param>
471+
public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column) => _annotations.GetGetter<TValue>(column);
472+
464473
public override ValueGetter<DataViewRowId> GetIdGetter() => (ref DataViewRowId dst) => dst = default;
465-
public override bool IsColumnActive(int col) => true;
474+
475+
/// <summary>
476+
/// Returns whether the given column is active in this row.
477+
/// </summary>
478+
public override bool IsColumnActive(DataViewSchema.Column column) => true;
466479
}
467480

468481
/// <summary>

src/Microsoft.ML.Core/Data/IRowToRowMapper.cs

+2-3
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ public interface IRowToRowMapper
3535

3636
/// <summary>
3737
/// Get a <see cref="DataViewRow"/> with the indicated active columns, based on the input <paramref name="input"/>.
38-
/// The active columns are those for which <paramref name="active"/> returns true. Getting values on inactive
39-
/// columns of the returned row will throw. Null predicates are disallowed.
38+
/// The active columns are those listed in <paramref name="activeColumns"/>. Getting values on inactive
/// columns of the returned row will throw.
4039
///
4140
/// The <see cref="DataViewRow.Schema"/> of <paramref name="input"/> should be the same object as
4241
/// <see cref="InputSchema"/>. Implementors of this method should throw if that is not the case. Conversely,
@@ -48,6 +47,6 @@ public interface IRowToRowMapper
4847
/// The output <see cref="DataViewRow"/> values are re-computed when requested through the getters. Also, the returned
4948
/// <see cref="DataViewRow"/> will dispose <paramref name="input"/> when it is disposed.
5049
/// </summary>
51-
DataViewRow GetRow(DataViewRow input, Func<int, bool> active);
50+
DataViewRow GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns);
5251
}
5352
}

src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs

+2-3
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,7 @@ internal interface ISchemaBoundRowMapper : ISchemaBoundMapper
7373

7474
/// <summary>
7575
/// Get a <see cref="DataViewRow"/> with the indicated active columns, based on the input <paramref name="input"/>.
76-
/// The active columns are those for which <paramref name="active"/> returns true. Getting values on inactive
77-
/// columns of the returned row will throw. Null predicates are disallowed.
76+
/// The active columns are those listed in <paramref name="activeColumns"/>. Getting values on inactive
/// columns of the returned row will throw.
7877
///
7978
/// The <see cref="DataViewRow.Schema"/> of <paramref name="input"/> should be the same object as
8079
/// <see cref="InputSchema"/>. Implementors of this method should throw if that is not the case. Conversely,
@@ -86,6 +85,6 @@ internal interface ISchemaBoundRowMapper : ISchemaBoundMapper
8685
/// The output <see cref="DataViewRow"/> values are re-computed when requested through the getters. Also, the returned
8786
/// <see cref="DataViewRow"/> will dispose <paramref name="input"/> when it is disposed.
8887
/// </summary>
89-
DataViewRow GetRow(DataViewRow input, Func<int, bool> active);
88+
DataViewRow GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns);
9089
}
9190
}

src/Microsoft.ML.Core/Data/LinkedRowRootCursorBase.cs

+16-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace Microsoft.ML.Data
99
/// <summary>
1010
/// A base class for a <see cref="DataViewRowCursor"/> that has an input cursor, but still needs to do work on
1111
/// <see cref="DataViewRowCursor.MoveNext"/>. Note that the default
12-
/// <see cref="LinkedRowRootCursorBase.GetGetter{TValue}(int)"/> assumes that each input column is exposed as an
12+
/// <see cref="LinkedRowRootCursorBase.GetGetter{TValue}(DataViewSchema.Column)"/> assumes that each input column is exposed as an
1313
/// output column with the same column index.
1414
/// </summary>
1515
[BestFriend]
@@ -29,15 +29,25 @@ protected LinkedRowRootCursorBase(IChannelProvider provider, DataViewRowCursor i
2929
Schema = schema;
3030
}
3131

32-
public sealed override bool IsColumnActive(int col)
32+
/// <summary>
33+
/// Returns whether the given column is active in this row.
34+
/// </summary>
35+
public sealed override bool IsColumnActive(DataViewSchema.Column column)
3336
{
34-
Ch.Check(0 <= col && col < Schema.Count);
35-
return _active == null || _active[col];
37+
Ch.Check(column.Index < Schema.Count);
38+
return _active == null || _active[column.Index];
3639
}
3740

38-
public override ValueGetter<TValue> GetGetter<TValue>(int col)
41+
/// <summary>
42+
/// Returns a value getter delegate to fetch the value of the given <paramref name="column"/> from the row.
43+
/// This throws if the column is not active in this row, or if the type
44+
/// <typeparamref name="TValue"/> differs from this column's type.
45+
/// </summary>
46+
/// <typeparam name="TValue"> is the column's content type.</typeparam>
47+
/// <param name="column"> is the output column whose getter should be returned.</param>
48+
public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
3949
{
40-
return Input.GetGetter<TValue>(col);
50+
return Input.GetGetter<TValue>(column);
4151
}
4252
}
4353
}

src/Microsoft.ML.Core/Utilities/Utils.cs

+35
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,41 @@ public static T[] BuildArray<T>(int length, Func<int, T> func)
776776
return result;
777777
}
778778

779+
/// <summary>
780+
/// Given a predicate over the columns of a schema, calculate first the
781+
/// indices of the columns for which that predicate is true, and second an
782+
/// inverse map from column index to position in the first array.
783+
/// </summary>
784+
/// <param name="schema">The input schema where the predicate can check if columns are active.</param>
785+
/// <param name="pred">The predicate to test on each column of <paramref name="schema"/>.</param>
786+
/// <param name="map">An ascending array of values from 0 inclusive
787+
/// to <c><paramref name="schema"/>.Count</c> exclusive, holding all values for which
788+
/// <paramref name="pred"/> is true</param>
789+
/// <param name="invMap">Forms an inverse mapping of <paramref name="map"/>,
790+
/// so that <c><paramref name="invMap"/>[<paramref name="map"/>[i]] == i</c>,
791+
/// and for other entries not appearing in <paramref name="map"/>,
792+
/// <c><paramref name="invMap"/>[i] == -1</c></param>
793+
public static void BuildSubsetMaps(DataViewSchema schema, Func<DataViewSchema.Column, bool> pred, out int[] map, out int[] invMap)
794+
{
795+
Contracts.CheckValue(schema, nameof(schema));
796+
Contracts.Check(schema.Count > 0, nameof(schema));
797+
Contracts.CheckValue(pred, nameof(pred));
798+
// REVIEW: Better names?
799+
List<int> mapList = new List<int>();
800+
invMap = new int[schema.Count];
801+
for (int c = 0; c < schema.Count; ++c)
802+
{
803+
if (!pred(schema[c]))
804+
{
805+
invMap[c] = -1;
806+
continue;
807+
}
808+
invMap[c] = mapList.Count;
809+
mapList.Add(c);
810+
}
811+
map = mapList.ToArray();
812+
}
813+
779814
/// <summary>
780815
/// Given a predicate, over a range of values defined by a limit calculate
781816
/// first the values for which that predicate was true, and second an inverse

0 commit comments

Comments
 (0)