Skip to content

Commit d7d4e99

Browse files
authored
Make SchemaShape.Column a struct instead of a class (#1822)
* Make SchemaShape.Column a struct instead of a class * Make SchemaShape an IReadOnlyList<SchemaShape.Column> and remove SchemaException * Make a (best) friend * Address comments * Make all member functions of SchemaShape.Column be best friends * Address comments * Change Columns' type from Column[] to ImmutableList<Column> * Fix build * Remove unnecessary checks because for some nullable columns * Address comments * SchemaShape is only a list not a class containing a list * Use array for immutable field and add a best friend function * Assert doesn't need variable name
1 parent 4c047bf commit d7d4e99

File tree

62 files changed

+145
-141
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+145
-141
lines changed

src/Microsoft.ML.Api/CustomMappingTransformer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
205205
var addedCols = DataViewConstructionUtils.GetSchemaColumns(Transformer.AddedSchema);
206206
var addedSchemaShape = SchemaShape.Create(SchemaBuilder.MakeSchema(addedCols));
207207

208-
var result = inputSchema.Columns.ToDictionary(x => x.Name);
208+
var result = inputSchema.ToDictionary(x => x.Name);
209209
var inputDef = InternalSchemaDefinition.Create(typeof(TSrc), Transformer.InputSchemaDefinition);
210210
foreach (var col in inputDef.Columns)
211211
{
@@ -223,7 +223,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
223223
}
224224
}
225225

226-
foreach (var addedCol in addedSchemaShape.Columns)
226+
foreach (var addedCol in addedSchemaShape)
227227
result[addedCol.Name] = addedCol;
228228

229229
return new SchemaShape(result.Values);

src/Microsoft.ML.Core/Data/IEstimator.cs

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
using Microsoft.ML.Data;
66
using Microsoft.ML.Runtime;
77
using Microsoft.ML.Runtime.Data;
8-
using System;
8+
using System.Collections;
99
using System.Collections.Generic;
10+
using System.Collections.Immutable;
1011
using System.Linq;
1112

1213
namespace Microsoft.ML.Core.Data
@@ -16,13 +17,17 @@ namespace Microsoft.ML.Core.Data
1617
/// This is more relaxed than the proper <see cref="ISchema"/>, since it's only a subset of the columns,
1718
/// and also since it doesn't specify exact <see cref="ColumnType"/>'s for vectors and keys.
1819
/// </summary>
19-
public sealed class SchemaShape
20+
public sealed class SchemaShape : IReadOnlyList<SchemaShape.Column>
2021
{
21-
public readonly Column[] Columns;
22+
private readonly Column[] _columns;
2223

2324
private static readonly SchemaShape _empty = new SchemaShape(Enumerable.Empty<Column>());
2425

25-
public sealed class Column
26+
/// <summary>
/// The number of columns held by this schema shape.
/// </summary>
// _columns is an array, so read its Length property directly instead of going through the LINQ Count() extension.
public int Count => _columns.Length;
27+
28+
public Column this[int index] => _columns[index];
29+
30+
public struct Column
2631
{
2732
public enum VectorKind
2833
{
@@ -55,13 +60,13 @@ public enum VectorKind
5560
/// </summary>
5661
public readonly SchemaShape Metadata;
5762

58-
public Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, SchemaShape metadata = null)
63+
[BestFriend]
64+
internal Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey, SchemaShape metadata = null)
5965
{
6066
Contracts.CheckNonEmpty(name, nameof(name));
6167
Contracts.CheckValueOrNull(metadata);
6268
Contracts.CheckParam(!itemType.IsKey, nameof(itemType), "Item type cannot be a key");
6369
Contracts.CheckParam(!itemType.IsVector, nameof(itemType), "Item type cannot be a vector");
64-
6570
Contracts.CheckParam(!isKey || KeyType.IsValidDataKind(itemType.RawKind), nameof(itemType), "The item type must be valid for a key");
6671

6772
Name = name;
@@ -80,9 +85,10 @@ public Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey,
8085
/// - The columns of <see cref="Metadata"/> of <paramref name="inputColumn"/> are a superset of our <see cref="Metadata"/> columns.
8186
/// - Each such metadata column is itself compatible with the input metadata column.
8287
/// </summary>
83-
public bool IsCompatibleWith(Column inputColumn)
88+
[BestFriend]
89+
internal bool IsCompatibleWith(Column inputColumn)
8490
{
85-
Contracts.CheckValue(inputColumn, nameof(inputColumn));
91+
Contracts.Check(inputColumn.IsValid, nameof(inputColumn));
8692
if (Name != inputColumn.Name)
8793
return false;
8894
if (Kind != inputColumn.Kind)
@@ -91,7 +97,7 @@ public bool IsCompatibleWith(Column inputColumn)
9197
return false;
9298
if (IsKey != inputColumn.IsKey)
9399
return false;
94-
foreach (var metaCol in Metadata.Columns)
100+
foreach (var metaCol in Metadata)
95101
{
96102
if (!inputColumn.Metadata.TryFindColumn(metaCol.Name, out var inputMetaCol))
97103
return false;
@@ -101,7 +107,8 @@ public bool IsCompatibleWith(Column inputColumn)
101107
return true;
102108
}
103109

104-
public string GetTypeString()
110+
[BestFriend]
111+
internal string GetTypeString()
105112
{
106113
string result = ItemType.ToString();
107114
if (IsKey)
@@ -112,13 +119,20 @@ public string GetTypeString()
112119
result = $"VarVector<{result}>";
113120
return result;
114121
}
122+
123+
/// <summary>
124+
/// Returns whether this structure differs from the default value of <see cref="Column"/>. If true,
125+
/// it means this structure was initialized properly and is therefore considered valid.
126+
/// </summary>
127+
[BestFriend]
128+
internal bool IsValid => Name != null;
115129
}
116130

117131
public SchemaShape(IEnumerable<Column> columns)
118132
{
119133
Contracts.CheckValue(columns, nameof(columns));
120-
Columns = columns.ToArray();
121-
Contracts.CheckParam(columns.All(c => c != null), nameof(columns), "No items should be null.");
134+
_columns = columns.ToArray();
135+
Contracts.CheckParam(columns.All(c => c.IsValid), nameof(columns), "Some items are not initialized properly.");
122136
}
123137

124138
/// <summary>
@@ -151,7 +165,8 @@ internal static void GetColumnTypeShape(ColumnType type,
151165
/// <summary>
152166
/// Create a schema shape out of the fully defined schema.
153167
/// </summary>
154-
public static SchemaShape Create(Schema schema)
168+
[BestFriend]
169+
internal static SchemaShape Create(Schema schema)
155170
{
156171
Contracts.CheckValue(schema, nameof(schema));
157172
var cols = new List<Column>();
@@ -179,25 +194,23 @@ public static SchemaShape Create(Schema schema)
179194
/// <summary>
180195
/// Returns whether there is a column with the specified <paramref name="name"/> and, if so, stores it in <paramref name="column"/>.
181196
/// </summary>
182-
public bool TryFindColumn(string name, out Column column)
197+
[BestFriend]
198+
internal bool TryFindColumn(string name, out Column column)
183199
{
184200
Contracts.CheckValue(name, nameof(name));
185-
column = Columns.FirstOrDefault(x => x.Name == name);
186-
return column != null;
201+
column = _columns.FirstOrDefault(x => x.Name == name);
202+
return column.IsValid;
187203
}
188204

205+
public IEnumerator<Column> GetEnumerator() => ((IEnumerable<Column>)_columns).GetEnumerator();
206+
207+
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
208+
189209
// REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape
190210
// as an input to another schema shape. I started writing, but realized that there's more than one way to check for
191211
// the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'.
192212
}
193213

194-
/// <summary>
195-
/// Exception class for schema validation errors.
196-
/// </summary>
197-
public class SchemaException : Exception
198-
{
199-
}
200-
201214
/// <summary>
202215
/// The 'data reader' takes a certain kind of input and turns it into an <see cref="IDataView"/>.
203216
/// </summary>
@@ -246,7 +259,6 @@ public interface ITransformer
246259
/// <summary>
247260
/// Schema propagation for transformers.
248261
/// Returns the output schema of the data, if the input schema is like the one provided.
249-
/// Throws <see cref="SchemaException"/> if the input schema is not valid for the transformer.
250262
/// </summary>
251263
Schema GetOutputSchema(Schema inputSchema);
252264

@@ -288,7 +300,6 @@ public interface IEstimator<out TTransformer>
288300
/// <summary>
289301
/// Schema propagation for estimators.
290302
/// Returns the output schema shape of the estimator, if the input schema shape is like the one provided.
291-
/// Throws <see cref="SchemaException"/> iff the input schema is not valid for the estimator.
292303
/// </summary>
293304
SchemaShape GetOutputSchema(SchemaShape inputSchema);
294305
}

src/Microsoft.ML.Core/Data/MetadataUtils.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ public static bool IsNormalized(this Schema schema, int col)
367367
/// of a scalar <see cref="BoolType"/> type, which we assume, if set, should be <c>true</c>.</returns>
368368
public static bool IsNormalized(this SchemaShape.Column col)
369369
{
370-
Contracts.CheckValue(col, nameof(col));
370+
Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly");
371371
return col.Metadata.TryFindColumn(Kinds.IsNormalized, out var metaCol)
372372
&& metaCol.Kind == SchemaShape.Column.VectorKind.Scalar && !metaCol.IsKey
373373
&& metaCol.ItemType == BoolType.Instance;
@@ -382,7 +382,7 @@ public static bool IsNormalized(this SchemaShape.Column col)
382382
/// <see cref="Kinds.SlotNames"/> metadata of definite sized vectors of text.</returns>
383383
public static bool HasSlotNames(this SchemaShape.Column col)
384384
{
385-
Contracts.CheckValue(col, nameof(col));
385+
Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly");
386386
return col.Kind == SchemaShape.Column.VectorKind.Vector
387387
&& col.Metadata.TryFindColumn(Kinds.SlotNames, out var metaCol)
388388
&& metaCol.Kind == SchemaShape.Column.VectorKind.Vector && !metaCol.IsKey

src/Microsoft.ML.Core/Utilities/Contracts.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,10 @@ public static void CheckAlive(this IHostEnvironment env)
758758
public static void CheckValueOrNull<T>(T val) where T : class
759759
{
760760
}
761+
762+
/// <summary>
763+
/// This documents that the parameter can legally be null.
764+
/// </summary>
761765
[Conditional("INVARIANT_CHECKS")]
762766
public static void CheckValueOrNull<T>(this IExceptionContext ctx, T val) where T : class
763767
{

src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,22 @@ public FakeSchema(IHostEnvironment env, SchemaShape inputShape)
3030
{
3131
_env = env;
3232
_shape = inputShape;
33-
_colMap = Enumerable.Range(0, _shape.Columns.Length)
34-
.ToDictionary(idx => _shape.Columns[idx].Name, idx => idx);
33+
_colMap = Enumerable.Range(0, _shape.Count)
34+
.ToDictionary(idx => _shape[idx].Name, idx => idx);
3535
}
3636

37-
public int ColumnCount => _shape.Columns.Length;
37+
public int ColumnCount => _shape.Count;
3838

3939
public string GetColumnName(int col)
4040
{
4141
_env.Check(0 <= col && col < ColumnCount);
42-
return _shape.Columns[col].Name;
42+
return _shape[col].Name;
4343
}
4444

4545
public ColumnType GetColumnType(int col)
4646
{
4747
_env.Check(0 <= col && col < ColumnCount);
48-
var inputCol = _shape.Columns[col];
48+
var inputCol = _shape[col];
4949
return MakeColumnType(inputCol);
5050
}
5151

@@ -66,7 +66,7 @@ private static ColumnType MakeColumnType(SchemaShape.Column inputCol)
6666
public void GetMetadata<TValue>(string kind, int col, ref TValue value)
6767
{
6868
_env.Check(0 <= col && col < ColumnCount);
69-
var inputCol = _shape.Columns[col];
69+
var inputCol = _shape[col];
7070
var metaShape = inputCol.Metadata;
7171
if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn))
7272
throw _env.ExceptGetMetadata();
@@ -89,7 +89,7 @@ public void GetMetadata<TValue>(string kind, int col, ref TValue value)
8989
public ColumnType GetMetadataTypeOrNull(string kind, int col)
9090
{
9191
_env.Check(0 <= col && col < ColumnCount);
92-
var inputCol = _shape.Columns[col];
92+
var inputCol = _shape[col];
9393
var metaShape = inputCol.Metadata;
9494
if (metaShape == null || !metaShape.TryFindColumn(kind, out var metaColumn))
9595
return null;
@@ -99,12 +99,12 @@ public ColumnType GetMetadataTypeOrNull(string kind, int col)
9999
public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
100100
{
101101
_env.Check(0 <= col && col < ColumnCount);
102-
var inputCol = _shape.Columns[col];
102+
var inputCol = _shape[col];
103103
var metaShape = inputCol.Metadata;
104104
if (metaShape == null)
105105
return Enumerable.Empty<KeyValuePair<string, ColumnType>>();
106106

107-
return metaShape.Columns.Select(c => new KeyValuePair<string, ColumnType>(c.Name, MakeColumnType(c)));
107+
return metaShape.Select(c => new KeyValuePair<string, ColumnType>(c.Name, MakeColumnType(c)));
108108
}
109109
}
110110
}

src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ public void Check(IExceptionContext ectx, SchemaShape shape)
113113

114114
private static Type GetTypeOrNull(SchemaShape.Column col)
115115
{
116-
Contracts.AssertValue(col);
116+
Contracts.Assert(col.IsValid);
117117

118118
Type vecType = null;
119119

src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,11 @@ public abstract class TrainerEstimatorBase<TTransformer, TModel> : ITrainerEstim
5555
private protected TrainerEstimatorBase(IHost host,
5656
SchemaShape.Column feature,
5757
SchemaShape.Column label,
58-
SchemaShape.Column weight = null)
58+
SchemaShape.Column weight = default)
5959
{
6060
Contracts.CheckValue(host, nameof(host));
6161
Host = host;
62-
Host.CheckValue(feature, nameof(feature));
63-
Host.CheckValueOrNull(label);
64-
Host.CheckValueOrNull(weight);
62+
Host.CheckParam(feature.IsValid, nameof(feature), "not initialized properly");
6563

6664
FeatureColumn = feature;
6765
LabelColumn = label;
@@ -76,7 +74,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
7674

7775
CheckInputSchema(inputSchema);
7876

79-
var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);
77+
var outColumns = inputSchema.ToDictionary(x => x.Name);
8078
foreach (var col in GetOutputColumnsCore(inputSchema))
8179
outColumns[col.Name] = col;
8280

@@ -102,7 +100,7 @@ private void CheckInputSchema(SchemaShape inputSchema)
102100
if (!FeatureColumn.IsCompatibleWith(featureCol))
103101
throw Host.Except($"Feature column '{FeatureColumn.Name}' is not compatible");
104102

105-
if (WeightColumn != null)
103+
if (WeightColumn.IsValid)
106104
{
107105
if (!inputSchema.TryFindColumn(WeightColumn.Name, out var weightCol))
108106
throw Host.Except($"Weight column '{WeightColumn.Name}' is not found");
@@ -112,7 +110,7 @@ private void CheckInputSchema(SchemaShape inputSchema)
112110

113111
// Special treatment for label column: we allow different types of labels, so the trainers
114112
// may define their own requirements on the label column.
115-
if (LabelColumn != null)
113+
if (LabelColumn.IsValid)
116114
{
117115
if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
118116
throw Host.Except($"Label column '{LabelColumn.Name}' is not found");
@@ -122,8 +120,8 @@ private void CheckInputSchema(SchemaShape inputSchema)
122120

123121
protected virtual void CheckLabelCompatible(SchemaShape.Column labelCol)
124122
{
125-
Contracts.CheckValue(labelCol, nameof(labelCol));
126-
Contracts.AssertValue(LabelColumn);
123+
Contracts.CheckParam(labelCol.IsValid, nameof(labelCol), "not initialized properly");
124+
Host.Assert(LabelColumn.IsValid);
127125

128126
if (!LabelColumn.IsCompatibleWith(labelCol))
129127
throw Host.Except($"Label column '{LabelColumn.Name}' is not compatible");
@@ -133,20 +131,12 @@ protected TTransformer TrainTransformer(IDataView trainSet,
133131
IDataView validationSet = null, IPredictor initPredictor = null)
134132
{
135133
var cachedTrain = Info.WantCaching ? new CacheDataView(Host, trainSet, prefetch: null) : trainSet;
134+
var cachedValid = Info.WantCaching && validationSet != null ? new CacheDataView(Host, validationSet, prefetch: null) : validationSet;
136135

137-
var trainRoles = MakeRoles(cachedTrain);
136+
var trainRoleMapped = MakeRoles(cachedTrain);
137+
var validRoleMapped = validationSet == null ? null : MakeRoles(cachedValid);
138138

139-
RoleMappedData validRoles;
140-
141-
if (validationSet == null)
142-
validRoles = null;
143-
else
144-
{
145-
var cachedValid = Info.WantCaching ? new CacheDataView(Host, validationSet, prefetch: null) : validationSet;
146-
validRoles = MakeRoles(cachedValid);
147-
}
148-
149-
var pred = TrainModelCore(new TrainContext(trainRoles, validRoles, null, initPredictor));
139+
var pred = TrainModelCore(new TrainContext(trainRoleMapped, validRoleMapped, null, initPredictor));
150140
return MakeTransformer(pred, trainSet.Schema);
151141
}
152142

@@ -156,7 +146,7 @@ protected TTransformer TrainTransformer(IDataView trainSet,
156146
protected abstract TTransformer MakeTransformer(TModel model, Schema trainSchema);
157147

158148
protected virtual RoleMappedData MakeRoles(IDataView data) =>
159-
new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, weight: WeightColumn?.Name);
149+
new RoleMappedData(data, label: LabelColumn.Name, feature: FeatureColumn.Name, weight: WeightColumn.Name);
160150

161151
IPredictor ITrainer.Train(TrainContext context) => ((ITrainer<TModel>)this).Train(context);
162152
}
@@ -178,16 +168,15 @@ public abstract class TrainerEstimatorBaseWithGroupId<TTransformer, TModel> : Tr
178168
public TrainerEstimatorBaseWithGroupId(IHost host,
179169
SchemaShape.Column feature,
180170
SchemaShape.Column label,
181-
SchemaShape.Column weight = null,
182-
SchemaShape.Column groupId = null)
171+
SchemaShape.Column weight = default,
172+
SchemaShape.Column groupId = default)
183173
:base(host, feature, label, weight)
184174
{
185-
Host.CheckValueOrNull(groupId);
186175
GroupIdColumn = groupId;
187176
}
188177

189178
protected override RoleMappedData MakeRoles(IDataView data) =>
190-
new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, group: GroupIdColumn?.Name, weight: WeightColumn?.Name);
179+
new RoleMappedData(data, label: LabelColumn.Name, feature: FeatureColumn.Name, group: GroupIdColumn.Name, weight: WeightColumn.Name);
191180

192181
}
193182
}

0 commit comments

Comments
 (0)