Skip to content

Commit 4dac819

Browse files
author
Pete Luferenko
committed
ColumnBindingsBase and ISchema are now internal. Schema and ColumnBindingsBase no longer implement ISchema.
1 parent 35ce1c0 commit 4dac819

33 files changed

+148
-396
lines changed

src/Microsoft.ML.Core/Data/IDataView.cs

+2 -1
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,8 @@ namespace Microsoft.ML.Runtime.Data
1212
/// Legacy interface for schema information.
1313
/// Please avoid implementing this interface, use <see cref="Schema"/>.
1414
/// </summary>
15-
public interface ISchema
15+
[BestFriend]
16+
internal interface ISchema
1617
{
1718
/// <summary>
1819
/// Number of columns.

src/Microsoft.ML.Core/Data/Schema.cs

+1 -30
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ namespace Microsoft.ML.Data
2222
/// and values.
2323
/// </summary>
2424
[System.Diagnostics.DebuggerTypeProxy(typeof(SchemaDebuggerProxy))]
25-
public sealed class Schema : ISchema, IReadOnlyList<Schema.Column>
25+
public sealed class Schema : IReadOnlyList<Schema.Column>
2626
{
2727
private readonly Column[] _columns;
2828
private readonly Dictionary<string, int> _nameMap;
@@ -273,9 +273,6 @@ internal static Schema Create(ISchema inputSchema)
273273
{
274274
Contracts.CheckValue(inputSchema, nameof(inputSchema));
275275

276-
if (inputSchema is Schema s)
277-
return s;
278-
279276
var builder = new SchemaBuilder();
280277
for (int i = 0; i < inputSchema.ColumnCount; i++)
281278
{
@@ -310,31 +307,5 @@ internal bool TryGetColumnIndex(string name, out int col)
310307
col = GetColumnOrNull(name)?.Index ?? -1;
311308
return col >= 0;
312309
}
313-
314-
#region Legacy schema API to be removed
315-
/// <summary>
316-
/// Number of columns in the schema.
317-
/// </summary>
318-
int ISchema.ColumnCount => _columns.Length;
319-
320-
string ISchema.GetColumnName(int col) => this[col].Name;
321-
322-
ColumnType ISchema.GetColumnType(int col) => this[col].Type;
323-
324-
IEnumerable<KeyValuePair<string, ColumnType>> ISchema.GetMetadataTypes(int col)
325-
=> this[col].Metadata.Schema.Select(c => new KeyValuePair<string, ColumnType>(c.Name, c.Type));
326-
327-
ColumnType ISchema.GetMetadataTypeOrNull(string kind, int col)
328-
=> this[col].Metadata.Schema.GetColumnOrNull(kind)?.Type;
329-
330-
void ISchema.GetMetadata<TValue>(string kind, int col, ref TValue value)
331-
=> this[col].Metadata.GetValue(kind, ref value);
332-
333-
bool ISchema.TryGetColumnIndex(string name, out int col)
334-
{
335-
col = GetColumnOrNull(name)?.Index ?? -1;
336-
return col >= 0;
337-
}
338-
#endregion
339310
}
340311
}

src/Microsoft.ML.Data/Data/ITransposeDataView.cs

+4 -2
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,8 @@ namespace Microsoft.ML.Runtime.Data
2525
/// is accessible in this fashion iff <see cref="TransposeSchema"/>'s
2626
/// <see cref="ITransposeSchema.GetSlotType"/> returns a non-null value.
2727
/// </summary>
28-
public interface ITransposeDataView : IDataView
28+
[BestFriend]
29+
internal interface ITransposeDataView : IDataView
2930
{
3031
/// <summary>
3132
/// An enhanced schema, containing information on the transposition properties, if any,
@@ -44,7 +45,8 @@ public interface ITransposeDataView : IDataView
4445
/// <summary>
4546
/// The transpose schema returns the schema information of the view we have transposed.
4647
/// </summary>
47-
public interface ITransposeSchema : ISchema
48+
[BestFriend]
49+
internal interface ITransposeSchema : ISchema
4850
{
4951
/// <summary>
5052
/// Analogous to <see cref="ISchema.GetColumnType"/>, except instead of returning the type of value

src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs

+4 -3
Original file line number | Diff line number | Diff line change
@@ -409,7 +409,7 @@ private CompositeDataLoader(IHost host, TransformEx[] transforms)
409409

410410
View = transforms[transforms.Length - 1].Transform;
411411
_tview = View as ITransposeDataView;
412-
TransposeSchema = _tview?.TransposeSchema ?? new TransposerUtils.SimpleTransposeSchema(View.Schema);
412+
_transposeSchema = _tview?.TransposeSchema ?? new TransposerUtils.SimpleTransposeSchema(View.Schema);
413413

414414
var srcLoader = transforms[0].Transform.Source as IDataLoader;
415415

@@ -566,7 +566,8 @@ private static string GenerateTag(int index)
566566

567567
public Schema Schema => View.Schema;
568568

569-
public ITransposeSchema TransposeSchema { get; }
569+
private readonly ITransposeSchema _transposeSchema;
570+
ITransposeSchema ITransposeDataView.TransposeSchema => _transposeSchema;
570571

571572
public RowCursor GetRowCursor(Func<int, bool> predicate, Random rand = null)
572573
{
@@ -586,7 +587,7 @@ public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator,
586587
public SlotCursor GetSlotCursor(int col)
587588
{
588589
_host.CheckParam(0 <= col && col < Schema.Count, nameof(col));
589-
if (TransposeSchema?.GetSlotType(col) == null)
590+
if (_transposeSchema?.GetSlotType(col) == null)
590591
{
591592
throw _host.ExceptParam(nameof(col), "Bad call to GetSlotCursor on untransposable column '{0}'",
592593
Schema[col].Name);

src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs

+5 -5
Original file line number | Diff line number | Diff line change
@@ -367,7 +367,7 @@ private static VersionInfo GetVersionInfo()
367367
// to use the cursors from the schema view if convenient to do so.
368368
public Schema Schema { get { return _schemaEntry.GetView().Schema; } }
369369

370-
public ITransposeSchema TransposeSchema { get { return _schema; } }
370+
ITransposeSchema ITransposeDataView.TransposeSchema { get { return _schema; } }
371371

372372
/// <summary>
373373
/// Whether the master schema sub-IDV has the actual data.
@@ -698,7 +698,7 @@ public SlotCursor GetSlotCursor(int col)
698698
_host.CheckParam(0 <= col && col < _header.ColumnCount, nameof(col));
699699
// We don't want the type error, if there is one, to be handled by the get-getter, because
700700
// at the point we've gotten the interior cursor, but not yet constructed the slot cursor.
701-
ColumnType cursorType = TransposeSchema.GetSlotType(col).ItemType;
701+
ColumnType cursorType = _schema.GetSlotType(col).ItemType;
702702
RowCursor inputCursor = view.GetRowCursor(c => true);
703703
try
704704
{
@@ -783,8 +783,8 @@ private Transposer EnsureAndGetTransposer(int col)
783783
_host.AssertValue(view);
784784
_host.Assert(view.Schema.Count == 1);
785785
var trans = _colTransposers[col] = Transposer.Create(_host, view, false, new int[] { 0 });
786-
_host.Assert(trans.TransposeSchema.ColumnCount == 1);
787-
_host.Assert(trans.TransposeSchema.GetSlotType(0).ValueCount == Schema[col].Type.ValueCount);
786+
_host.Assert(((ITransposeDataView)trans).TransposeSchema.ColumnCount == 1);
787+
_host.Assert(((ITransposeDataView)trans).TransposeSchema.GetSlotType(0).ValueCount == Schema[col].Type.ValueCount);
788788
}
789789
}
790790
}
@@ -845,7 +845,7 @@ private void Init(int col)
845845
Ch.Assert(0 <= col && col < Schema.Count);
846846
Ch.Assert(_colToActivesIndex[col] >= 0);
847847
var type = Schema[col].Type;
848-
Ch.Assert(_parent.TransposeSchema.GetSlotType(col).ValueCount == _parent._header.RowCount);
848+
Ch.Assert(((ITransposeDataView)_parent).TransposeSchema.GetSlotType(col).ValueCount == _parent._header.RowCount);
849849
Action<int> func = InitOne<int>;
850850
if (type.IsVector)
851851
func = InitVec<int>;

src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs

+2 -2
Original file line number | Diff line number | Diff line change
@@ -92,7 +92,7 @@ internal RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapp
9292
Contracts.CheckValueOrNull(mapperFactory);
9393
_mapper = mapper;
9494
_mapperFactory = mapperFactory;
95-
_bindings = new ColumnBindings(Schema.Create(input.Schema), mapper.GetOutputColumns());
95+
_bindings = new ColumnBindings(input.Schema, mapper.GetOutputColumns());
9696
}
9797

9898
[BestFriend]
@@ -110,7 +110,7 @@ private RowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView inpu
110110
// _mapper
111111

112112
ctx.LoadModel<IRowMapper, SignatureLoadRowMapper>(host, out _mapper, "Mapper", input.Schema);
113-
_bindings = new ColumnBindings(Schema.Create(input.Schema), _mapper.GetOutputColumns());
113+
_bindings = new ColumnBindings(input.Schema, _mapper.GetOutputColumns());
114114
}
115115

116116
public static RowToRowMapperTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)

src/Microsoft.ML.Data/DataView/Transposer.cs

+4 -3
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,8 @@ namespace Microsoft.ML.Runtime.Data
2121
/// were not transposable before. Note that transposition is a somewhat slow and resource intensive
2222
/// operation.
2323
/// </summary>
24-
public sealed class Transposer : ITransposeDataView, IDisposable
24+
[BestFriend]
25+
internal sealed class Transposer : ITransposeDataView, IDisposable
2526
{
2627
private readonly IHost _host;
2728
// The input view.
@@ -384,7 +385,7 @@ public override ValueGetter<VBuffer<TValue>> GetGetter<TValue>()
384385

385386
public override VectorType GetSlotType()
386387
{
387-
return _parent.TransposeSchema.GetSlotType(_col);
388+
return ((ITransposeDataView)_parent).TransposeSchema.GetSlotType(_col);
388389
}
389390

390391
protected abstract ValueGetter<VBuffer<T>> GetGetterCore();
@@ -1402,7 +1403,7 @@ public override ValueGetter<TValue> GetGetter<TValue>(int col)
14021403
}
14031404
}
14041405

1405-
public static class TransposerUtils
1406+
internal static class TransposerUtils
14061407
{
14071408
/// <summary>
14081409
/// This is a convenience method that extracts a single slot value's vector,

src/Microsoft.ML.Data/EntryPoints/TransformModelImpl.cs

+1 -1
Original file line number | Diff line number | Diff line change
@@ -207,7 +207,7 @@ public CompositeRowToRowMapper(IExceptionContext ectx, IDataView chain, Schema r
207207
_ectx.CheckValue(rootSchema, nameof(rootSchema));
208208

209209
_chain = chain;
210-
_rootSchema = Schema.Create(rootSchema);
210+
_rootSchema = rootSchema;
211211
}
212212

213213
public static bool IsCompositeRowToRowMapper(IDataView chain)

src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs

+1 -1
Original file line number | Diff line number | Diff line change
@@ -422,7 +422,7 @@ public static MultiOutputRegressionPerInstanceEvaluator Create(IHostEnvironment
422422
env.CheckValue(ctx, nameof(ctx));
423423
ctx.CheckAtModel(GetVersionInfo());
424424

425-
return new MultiOutputRegressionPerInstanceEvaluator(env, ctx, Schema.Create(schema));
425+
return new MultiOutputRegressionPerInstanceEvaluator(env, ctx, schema);
426426
}
427427

428428
public override void Save(ModelSaveContext ctx)

src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs

+2 -2
Original file line number | Diff line number | Diff line change
@@ -629,7 +629,7 @@ private sealed class Bindings : BindingsBase
629629
private readonly int _truncationLevel;
630630
private readonly MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>> _slotNamesGetter;
631631

632-
public Bindings(IExceptionContext ectx, ISchema input, bool user, string labelCol, string scoreCol, string groupCol,
632+
public Bindings(IExceptionContext ectx, Schema input, bool user, string labelCol, string scoreCol, string groupCol,
633633
int truncationLevel)
634634
: base(ectx, input, labelCol, scoreCol, groupCol, user, Ndcg, Dcg, MaxDcg)
635635
{
@@ -736,7 +736,7 @@ public override void Save(ModelSaveContext ctx)
736736
ctx.Writer.WriteDoubleArray(_labelGains);
737737
}
738738

739-
protected override BindingsBase GetBindings()
739+
private protected override BindingsBase GetBindings()
740740
{
741741
return _bindings;
742742
}

src/Microsoft.ML.Data/Scorers/GenericScorer.cs

+3 -3
Original file line number | Diff line number | Diff line change
@@ -161,7 +161,7 @@ internal GenericScorer(IHostEnvironment env, ScorerArgumentsBase args, IDataView
161161
var rowMapper = mapper as ISchemaBoundRowMapper;
162162
Host.CheckParam(rowMapper != null, nameof(mapper), "mapper should implement ISchemaBoundRowMapper");
163163
_bindings = Bindings.Create(data.Schema, rowMapper, args.Suffix);
164-
OutputSchema = Schema.Create(_bindings);
164+
OutputSchema = _bindings.AsSchema;
165165
}
166166

167167
/// <summary>
@@ -171,7 +171,7 @@ private GenericScorer(IHostEnvironment env, GenericScorer transform, IDataView d
171171
: base(env, data, RegistrationName, transform.Bindable)
172172
{
173173
_bindings = transform._bindings.ApplyToSchema(env, data.Schema);
174-
OutputSchema = Schema.Create(_bindings);
174+
OutputSchema = _bindings.AsSchema;
175175
}
176176

177177
/// <summary>
@@ -182,7 +182,7 @@ private GenericScorer(IHost host, ModelLoadContext ctx, IDataView input)
182182
{
183183
Contracts.AssertValue(ctx);
184184
_bindings = Bindings.Create(ctx, host, Bindable, input.Schema);
185-
OutputSchema = Schema.Create(_bindings);
185+
OutputSchema = _bindings.AsSchema;
186186
}
187187

188188
/// <summary>

src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs

+3 -3
Original file line number | Diff line number | Diff line change
@@ -307,15 +307,15 @@ private protected PredictedLabelScorerBase(ScorerArgumentsBase args, IHostEnviro
307307
var predColType = getPredColType(scoreType, rowMapper);
308308

309309
Bindings = BindingsImpl.Create(data.Schema, rowMapper, args.Suffix, scoreColKind, scoreColIndex, predColType);
310-
OutputSchema = Schema.Create(Bindings);
310+
OutputSchema = Bindings.AsSchema;
311311
}
312312

313313
protected PredictedLabelScorerBase(IHostEnvironment env, PredictedLabelScorerBase transform,
314314
IDataView newSource, string registrationName)
315315
: base(env, newSource, registrationName, transform.Bindable)
316316
{
317317
Bindings = transform.Bindings.ApplyToSchema(newSource.Schema, Bindable, env);
318-
OutputSchema = Schema.Create(Bindings);
318+
OutputSchema = Bindings.AsSchema;
319319
}
320320

321321
[BestFriend]
@@ -329,7 +329,7 @@ private protected PredictedLabelScorerBase(IHost host, ModelLoadContext ctx, IDa
329329
Host.AssertValue(getPredColType);
330330

331331
Bindings = BindingsImpl.Create(ctx, input.Schema, host, Bindable, outputTypeMatches, getPredColType);
332-
OutputSchema = Schema.Create(Bindings);
332+
OutputSchema = Bindings.AsSchema;
333333
}
334334

335335
private protected override void SaveCore(ModelSaveContext ctx)

src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs

+1 -1
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,7 @@ private static bool[] GetActive(BindingsBase bindings, Func<int, bool> predicate
8686
Contracts.Assert(active.Length == bindings.ColumnCount);
8787

8888
var activeInput = bindings.GetActiveInput(predicate);
89-
Contracts.Assert(activeInput.Length == bindings.Input.ColumnCount);
89+
Contracts.Assert(activeInput.Length == bindings.Input.Count);
9090

9191
// Get a predicate that determines which Mapper outputs are active.
9292
predicateMapper = bindings.GetActiveMapperColumns(active);

0 commit comments

Comments
 (0)