Skip to content

Commit f55f840

Browse files
authored
Making Schema implement ISchema explicitly (#1894)
* Removed ColumnCount and GetColumns
* Removed GetColumnName and GetColumnType
* Removed GetMetadataTypes and GetMetadataTypeOrNull
* Removed GetMetadata and hid TryGetColumnIndex
1 parent 7037a0d commit f55f840

File tree

169 files changed

+1067
-1180
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

169 files changed

+1067
-1180
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/FeatureContributionCalculationTransform.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ public static void FeatureContributionCalculationTransform_Regression()
8585
var value = row.Features[featureOfInterest];
8686
var contribution = row.FeatureContributions[featureOfInterest];
8787
var percentContribution = 100 * contribution / row.Score;
88-
var name = data.Schema.GetColumnName(featureOfInterest + 1);
88+
var name = data.Schema[(int) (featureOfInterest + 1)].Name;
8989
var weight = weights.GetValues()[featureOfInterest];
9090

9191
Console.WriteLine("{0:0.00}\t{1:0.00}\t{2}\t{3:0.00}\t{4:0.00}\t{5:0.00}\t{6:0.00}",

docs/samples/Microsoft.ML.Samples/Dynamic/GeneralizedAdditiveModels.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ public static void RunExample()
4848
// and use a small number of bins to make it easy to visualize in the console window.
4949
// For real applications, it is recommended to start with the default number of bins.
5050
var labelName = "MedianHomeValue";
51-
var featureNames = data.Schema.GetColumns()
52-
.Select(tuple => tuple.column.Name) // Get the column names
51+
var featureNames = data.Schema
52+
.Select(column => column.Name) // Get the column names
5353
.Where(name => name != labelName) // Drop the Label
5454
.ToArray();
5555
var pipeline = mlContext.Transforms.Concatenate("Features", featureNames)

docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ public static void PFI_Regression()
7171
// Now let's look at which features are most important to the model overall
7272
// First, we have to prepare the data:
7373
// Get the feature names as an IEnumerable
74-
var featureNames = data.Schema.GetColumns()
75-
.Select(tuple => tuple.column.Name) // Get the column names
74+
var featureNames = data.Schema
75+
.Select(column => column.Name) // Get the column names
7676
.Where(name => name != labelName) // Drop the Label
7777
.ToArray();
7878

src/Microsoft.ML.Core/Data/IEstimator.cs

+5-5
Original file line numberDiff line numberDiff line change
@@ -177,15 +177,15 @@ internal static SchemaShape Create(Schema schema)
177177
{
178178
// First create the metadata.
179179
var mCols = new List<Column>();
180-
foreach (var metaNameType in schema.GetMetadataTypes(iCol))
180+
foreach (var metaColumn in schema[iCol].Metadata.Schema)
181181
{
182-
GetColumnTypeShape(metaNameType.Value, out var mVecKind, out var mItemType, out var mIsKey);
183-
mCols.Add(new Column(metaNameType.Key, mVecKind, mItemType, mIsKey));
182+
GetColumnTypeShape(metaColumn.Type, out var mVecKind, out var mItemType, out var mIsKey);
183+
mCols.Add(new Column(metaColumn.Name, mVecKind, mItemType, mIsKey));
184184
}
185185
var metadata = mCols.Count > 0 ? new SchemaShape(mCols) : _empty;
186186
// Next create the single column.
187-
GetColumnTypeShape(schema.GetColumnType(iCol), out var vecKind, out var itemType, out var isKey);
188-
cols.Add(new Column(schema.GetColumnName(iCol), vecKind, itemType, isKey, metadata));
187+
GetColumnTypeShape(schema[iCol].Type, out var vecKind, out var itemType, out var isKey);
188+
cols.Add(new Column(schema[iCol].Name, vecKind, itemType, isKey, metadata));
189189
}
190190
}
191191
return new SchemaShape(cols);

src/Microsoft.ML.Core/Data/ISchemaBindableMapper.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public interface IRowToRowMapper
8787

8888
/// <summary>
8989
/// Given a predicate specifying which columns are needed, return a predicate indicating which input columns are
90-
/// needed. The domain of the function is defined over the indices of the columns of <see cref="Schema.ColumnCount"/>
90+
/// needed. The domain of the function is defined over the indices of the columns of <see cref="Schema.Count"/>
9191
/// for <see cref="InputSchema"/>.
9292
/// </summary>
9393
Func<int, bool> GetDependencies(Func<int, bool> predicate);

src/Microsoft.ML.Core/Data/MetadataUtils.cs

+10-10
Original file line numberDiff line numberDiff line change
@@ -239,13 +239,13 @@ public static uint GetMaxMetadataKind(this Schema schema, out int colMax, string
239239
colMax = -1;
240240
for (int col = 0; col < schema.Count; col++)
241241
{
242-
var columnType = schema.GetMetadataTypeOrNull(metadataKind, col);
242+
var columnType = schema[col].Metadata.Schema.GetColumnOrNull(metadataKind)?.Type;
243243
if (columnType == null || !columnType.IsKey || columnType.RawKind != DataKind.U4)
244244
continue;
245245
if (filterFunc != null && !filterFunc(schema, col))
246246
continue;
247247
uint value = 0;
248-
schema.GetMetadata(metadataKind, col, ref value);
248+
schema[col].Metadata.GetValue(metadataKind, ref value);
249249
if (max < value)
250250
{
251251
max = value;
@@ -264,11 +264,11 @@ internal static IEnumerable<int> GetColumnSet(this Schema schema, string metadat
264264
{
265265
for (int col = 0; col < schema.Count; col++)
266266
{
267-
var columnType = schema.GetMetadataTypeOrNull(metadataKind, col);
267+
var columnType = schema[col].Metadata.Schema.GetColumnOrNull(metadataKind)?.Type;
268268
if (columnType != null && columnType.IsKey && columnType.RawKind == DataKind.U4)
269269
{
270270
uint val = 0;
271-
schema.GetMetadata(metadataKind, col, ref val);
271+
schema[col].Metadata.GetValue(metadataKind, ref val);
272272
if (val == value)
273273
yield return col;
274274
}
@@ -284,11 +284,11 @@ internal static IEnumerable<int> GetColumnSet(this Schema schema, string metadat
284284
{
285285
for (int col = 0; col < schema.Count; col++)
286286
{
287-
var columnType = schema.GetMetadataTypeOrNull(metadataKind, col);
287+
var columnType = schema[col].Metadata.Schema.GetColumnOrNull(metadataKind)?.Type;
288288
if (columnType != null && columnType.IsText)
289289
{
290290
ReadOnlyMemory<char> val = default;
291-
schema.GetMetadata(metadataKind, col, ref val);
291+
schema[col].Metadata.GetValue(metadataKind, ref val);
292292
if (ReadOnlyMemoryUtils.EqualsStr(value, val))
293293
yield return col;
294294
}
@@ -338,7 +338,7 @@ internal static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.Colu
338338
VBufferUtils.Resize(ref slotNames, vectorSize, 0);
339339
}
340340
else
341-
schema.Schema.GetMetadata(Kinds.SlotNames, list[0].Index, ref slotNames);
341+
schema.Schema[list[0].Index].Metadata.GetValue(Kinds.SlotNames, ref slotNames);
342342
}
343343

344344
[BestFriend]
@@ -447,14 +447,14 @@ internal static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex
447447

448448
bool isValid = false;
449449
categoricalFeatures = null;
450-
if (!(schema.GetColumnType(colIndex) is VectorType vecType && vecType.Size > 0))
450+
if (!(schema[colIndex].Type is VectorType vecType && vecType.Size > 0))
451451
return isValid;
452452

453-
var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex);
453+
var type = schema[colIndex].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.CategoricalSlotRanges)?.Type;
454454
if (type?.RawType == typeof(VBuffer<int>))
455455
{
456456
VBuffer<int> catIndices = default(VBuffer<int>);
457-
schema.GetMetadata(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex, ref catIndices);
457+
schema[colIndex].Metadata.GetValue(MetadataUtils.Kinds.CategoricalSlotRanges, ref catIndices);
458458
VBufferUtils.Densify(ref catIndices);
459459
int columnSlotsCount = vecType.Size;
460460
if (catIndices.Length > 0 && catIndices.Length % 2 == 0 && catIndices.Length <= columnSlotsCount * 2)

src/Microsoft.ML.Core/Data/RoleMappedSchema.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public static bool TryCreateFromName(Schema schema, string name, out ColumnInfo
5353
if (!schema.TryGetColumnIndex(name, out int index))
5454
return false;
5555

56-
colInfo = new ColumnInfo(name, index, schema.GetColumnType(index));
56+
colInfo = new ColumnInfo(name, index, schema[index].Type);
5757
return true;
5858
}
5959

@@ -65,9 +65,9 @@ public static bool TryCreateFromName(Schema schema, string name, out ColumnInfo
6565
public static ColumnInfo CreateFromIndex(Schema schema, int index)
6666
{
6767
Contracts.CheckValue(schema, nameof(schema));
68-
Contracts.CheckParam(0 <= index && index < schema.ColumnCount, nameof(index));
68+
Contracts.CheckParam(0 <= index && index < schema.Count, nameof(index));
6969

70-
return new ColumnInfo(schema.GetColumnName(index), index, schema.GetColumnType(index));
70+
return new ColumnInfo(schema[index].Name, index, schema[index].Type);
7171
}
7272
}
7373

src/Microsoft.ML.Core/Data/Schema.cs

+26-37
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,6 @@ public sealed class Schema : ISchema, IReadOnlyList<Schema.Column>
2727
private readonly Column[] _columns;
2828
private readonly Dictionary<string, int> _nameMap;
2929

30-
/// <summary>
31-
/// Number of columns in the schema.
32-
/// </summary>
33-
public int ColumnCount => _columns.Length;
34-
3530
/// <summary>
3631
/// Number of columns in the schema.
3732
/// </summary>
@@ -249,7 +244,7 @@ public void GetValue<TValue>(string kind, ref TValue value)
249244
GetGetter<TValue>(column.Value.Index)(ref value);
250245
}
251246

252-
public override string ToString() => string.Join(", ", Schema.GetColumns().Select(x => x.column.Name));
247+
public override string ToString() => string.Join(", ", Schema.Select(x => x.Name));
253248

254249
}
255250

@@ -270,11 +265,6 @@ internal Schema(Column[] columns)
270265
}
271266
}
272267

273-
/// <summary>
274-
/// Get all non-hidden columns as pairs of (index, <see cref="Column"/>).
275-
/// </summary>
276-
public IEnumerable<(int index, Column column)> GetColumns() => _nameMap.Values.Select(idx => (idx, _columns[idx]));
277-
278268
/// <summary>
279269
/// Manufacture an instance of <see cref="Schema"/> out of any <see cref="ISchema"/>.
280270
/// </summary>
@@ -310,38 +300,37 @@ private static Delegate GetMetadataGetterDelegate<TValue>(ISchema schema, int co
310300
return getter;
311301
}
312302

303+
/// <summary>
304+
/// Legacy method to get the column index.
305+
/// DO NOT USE: use <see cref="GetColumnOrNull"/> instead.
306+
/// </summary>
307+
[BestFriend]
308+
internal bool TryGetColumnIndex(string name, out int col)
309+
{
310+
col = GetColumnOrNull(name)?.Index ?? -1;
311+
return col >= 0;
312+
}
313+
313314
#region Legacy schema API to be removed
314-
public string GetColumnName(int col) => this[col].Name;
315+
/// <summary>
316+
/// Number of columns in the schema.
317+
/// </summary>
318+
int ISchema.ColumnCount => _columns.Length;
315319

316-
public ColumnType GetColumnType(int col) => this[col].Type;
320+
string ISchema.GetColumnName(int col) => this[col].Name;
317321

318-
public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
319-
{
320-
var meta = this[col].Metadata;
321-
if (meta == null)
322-
return Enumerable.Empty<KeyValuePair<string, ColumnType>>();
323-
return meta.Schema.GetColumns().Select(c => new KeyValuePair<string, ColumnType>(c.column.Name, c.column.Type));
324-
}
322+
ColumnType ISchema.GetColumnType(int col) => this[col].Type;
325323

326-
public ColumnType GetMetadataTypeOrNull(string kind, int col)
327-
{
328-
var meta = this[col].Metadata;
329-
if (meta == null)
330-
return null;
331-
if (meta.Schema.TryGetColumnIndex(kind, out int metaCol))
332-
return meta.Schema[metaCol].Type;
333-
return null;
334-
}
324+
IEnumerable<KeyValuePair<string, ColumnType>> ISchema.GetMetadataTypes(int col)
325+
=> this[col].Metadata.Schema.Select(c => new KeyValuePair<string, ColumnType>(c.Name, c.Type));
335326

336-
public void GetMetadata<TValue>(string kind, int col, ref TValue value)
337-
{
338-
var meta = this[col].Metadata;
339-
if (meta == null)
340-
throw MetadataUtils.ExceptGetMetadata();
341-
meta.GetValue(kind, ref value);
342-
}
327+
ColumnType ISchema.GetMetadataTypeOrNull(string kind, int col)
328+
=> this[col].Metadata.Schema.GetColumnOrNull(kind)?.Type;
329+
330+
void ISchema.GetMetadata<TValue>(string kind, int col, ref TValue value)
331+
=> this[col].Metadata.GetValue(kind, ref value);
343332

344-
public bool TryGetColumnIndex(string name, out int col)
333+
bool ISchema.TryGetColumnIndex(string name, out int col)
345334
{
346335
col = GetColumnOrNull(name)?.Index ?? -1;
347336
return col >= 0;

src/Microsoft.ML.Core/Data/SchemaDebuggerProxy.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ public MetadataDebuggerProxy(Schema.Metadata metadata)
3939
private static List<KeyValuePair<string, object>> BuildValues(Schema.Metadata metadata)
4040
{
4141
var result = new List<KeyValuePair<string, object>>();
42-
foreach ((var index, var column) in metadata.Schema.GetColumns())
42+
foreach (var column in metadata.Schema)
4343
{
4444
var name = column.Name;
45-
var value = Utils.MarshalInvoke(GetValue<int>, column.Type.RawType, metadata, index);
45+
var value = Utils.MarshalInvoke(GetValue<int>, column.Type.RawType, metadata, column.Index);
4646
result.Add(new KeyValuePair<string, object>(name, value));
4747
}
4848
return result;

src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output
298298
if (group != null && schema.TryGetColumnIndex(group, out index))
299299
{
300300
// Check if group column key type with known cardinality.
301-
var type = schema.GetColumnType(index);
301+
var type = schema[index].Type;
302302
if (type.KeyCount > 0)
303303
stratificationColumn = group;
304304
}
@@ -322,7 +322,7 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output
322322
int col;
323323
if (!input.Schema.TryGetColumnIndex(stratificationColumn, out col))
324324
throw ch.ExceptUserArg(nameof(Arguments.StratificationColumn), "Column '{0}' does not exist", stratificationColumn);
325-
var type = input.Schema.GetColumnType(col);
325+
var type = input.Schema[col].Type;
326326
if (!RangeFilter.IsValidRangeFilterColumnType(ch, type))
327327
{
328328
ch.Info("Hashing the stratification column");

src/Microsoft.ML.Data/Commands/DataCommand.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@ protected void SendTelemetryMetric(Dictionary<string, IDataView>[] metricValues)
165165
{
166166
for (int currentIndex = 0; currentIndex < cursor.Schema.Count; currentIndex++)
167167
{
168-
var nameOfMetric = "TLC_" + cursor.Schema.GetColumnName(currentIndex);
169-
var type = cursor.Schema.GetColumnType(currentIndex);
168+
var nameOfMetric = "TLC_" + cursor.Schema[currentIndex].Name;
169+
var type = cursor.Schema[currentIndex].Type;
170170
if (type.IsNumber)
171171
{
172172
var getter = RowCursorUtils.GetGetterAs<double>(NumberType.R8, cursor, currentIndex);

src/Microsoft.ML.Data/Commands/SaveDataCommand.cs

+5-5
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,11 @@ private void RunCore(IChannel ch)
146146
{
147147
if (!Args.KeepHidden && data.Schema[i].IsHidden)
148148
continue;
149-
var type = data.Schema.GetColumnType(i);
149+
var type = data.Schema[i].Type;
150150
if (saver.IsColumnSavable(type))
151151
cols.Add(i);
152152
else
153-
ch.Info(MessageSensitivity.Schema, "The column '{0}' will not be written as it has unsavable column type.", data.Schema.GetColumnName(i));
153+
ch.Info(MessageSensitivity.Schema, "The column '{0}' will not be written as it has unsavable column type.", data.Schema[i].Name);
154154
}
155155
Host.NotSensitive().Check(cols.Count > 0, "No valid columns to save");
156156

@@ -203,15 +203,15 @@ public static void SaveDataView(IChannel ch, IDataSaver saver, IDataView view, S
203203
ch.CheckValue(stream, nameof(stream));
204204

205205
var cols = new List<int>();
206-
for (int i = 0; i < view.Schema.ColumnCount; i++)
206+
for (int i = 0; i < view.Schema.Count; i++)
207207
{
208208
if (!keepHidden && view.Schema[i].IsHidden)
209209
continue;
210-
var type = view.Schema.GetColumnType(i);
210+
var type = view.Schema[i].Type;
211211
if (saver.IsColumnSavable(type))
212212
cols.Add(i);
213213
else
214-
ch.Info(MessageSensitivity.Schema, "The column '{0}' will not be written as it has unsavable column type.", view.Schema.GetColumnName(i));
214+
ch.Info(MessageSensitivity.Schema, "The column '{0}' will not be written as it has unsavable column type.", view.Schema[i].Name);
215215
}
216216

217217
ch.Check(cols.Count > 0, "No valid columns to save");

src/Microsoft.ML.Data/Commands/ScoreCommand.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,13 @@ private void RunCore(IChannel ch)
185185
continue;
186186
if (!(outputAllColumns || ShouldAddColumn(loader.Schema, i, maxScoreId, outputNamesAndLabels)))
187187
continue;
188-
var type = loader.Schema.GetColumnType(i);
188+
var type = loader.Schema[i].Type;
189189
if (writer.IsColumnSavable(type))
190190
cols.Add(i);
191191
else
192192
{
193193
ch.Warning("The column '{0}' will not be written as it has unsavable column type.",
194-
loader.Schema.GetColumnName(i));
194+
loader.Schema[i].Name);
195195
}
196196
}
197197

@@ -217,7 +217,7 @@ private bool ShouldAddColumn(Schema schema, int i, uint scoreSet, bool outputNam
217217
}
218218
if (outputNamesAndLabels)
219219
{
220-
switch (schema.GetColumnName(i))
220+
switch (schema[i].Name)
221221
{
222222
case "Label":
223223
case "Name":
@@ -227,7 +227,7 @@ private bool ShouldAddColumn(Schema schema, int i, uint scoreSet, bool outputNam
227227
break;
228228
}
229229
}
230-
if (Args.OutputColumn != null && Array.FindIndex(Args.OutputColumn, schema.GetColumnName(i).Equals) >= 0)
230+
if (Args.OutputColumn != null && Array.FindIndex(Args.OutputColumn, schema[i].Name.Equals) >= 0)
231231
return true;
232232
return false;
233233
}

0 commit comments

Comments
 (0)