Skip to content

ISchema is now internal #1917

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Core/Data/IDataView.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ namespace Microsoft.ML.Runtime.Data
/// Legacy interface for schema information.
/// Please avoid implementing this interface, use <see cref="Schema"/>.
/// </summary>
public interface ISchema
[BestFriend]
internal interface ISchema
{
/// <summary>
/// Number of columns.
Expand Down
31 changes: 1 addition & 30 deletions src/Microsoft.ML.Core/Data/Schema.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace Microsoft.ML.Data
/// and values.
/// </summary>
[System.Diagnostics.DebuggerTypeProxy(typeof(SchemaDebuggerProxy))]
public sealed class Schema : ISchema, IReadOnlyList<Schema.Column>
public sealed class Schema : IReadOnlyList<Schema.Column>
{
private readonly Column[] _columns;
private readonly Dictionary<string, int> _nameMap;
Expand Down Expand Up @@ -273,9 +273,6 @@ internal static Schema Create(ISchema inputSchema)
{
Contracts.CheckValue(inputSchema, nameof(inputSchema));

if (inputSchema is Schema s)
return s;

var builder = new SchemaBuilder();
for (int i = 0; i < inputSchema.ColumnCount; i++)
{
Expand Down Expand Up @@ -310,31 +307,5 @@ internal bool TryGetColumnIndex(string name, out int col)
col = GetColumnOrNull(name)?.Index ?? -1;
return col >= 0;
}

#region Legacy schema API to be removed
/// <summary>
/// Number of columns in the schema.
/// </summary>
int ISchema.ColumnCount => _columns.Length;

string ISchema.GetColumnName(int col) => this[col].Name;

ColumnType ISchema.GetColumnType(int col) => this[col].Type;

IEnumerable<KeyValuePair<string, ColumnType>> ISchema.GetMetadataTypes(int col)
=> this[col].Metadata.Schema.Select(c => new KeyValuePair<string, ColumnType>(c.Name, c.Type));

ColumnType ISchema.GetMetadataTypeOrNull(string kind, int col)
=> this[col].Metadata.Schema.GetColumnOrNull(kind)?.Type;

void ISchema.GetMetadata<TValue>(string kind, int col, ref TValue value)
=> this[col].Metadata.GetValue(kind, ref value);

bool ISchema.TryGetColumnIndex(string name, out int col)
{
col = GetColumnOrNull(name)?.Index ?? -1;
return col >= 0;
}
#endregion
}
}
6 changes: 4 additions & 2 deletions src/Microsoft.ML.Data/Data/ITransposeDataView.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ namespace Microsoft.ML.Runtime.Data
/// is accessible in this fashion iff <see cref="TransposeSchema"/>'s
/// <see cref="ITransposeSchema.GetSlotType"/> returns a non-null value.
/// </summary>
public interface ITransposeDataView : IDataView
[BestFriend]
internal interface ITransposeDataView : IDataView
{
/// <summary>
/// An enhanced schema, containing information on the transposition properties, if any,
Expand All @@ -44,7 +45,8 @@ public interface ITransposeDataView : IDataView
/// <summary>
/// The transpose schema returns the schema information of the view we have transposed.
/// </summary>
public interface ITransposeSchema : ISchema
[BestFriend]
internal interface ITransposeSchema : ISchema
{
/// <summary>
/// Analogous to <see cref="ISchema.GetColumnType"/>, except instead of returning the type of value
Expand Down
7 changes: 4 additions & 3 deletions src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ private CompositeDataLoader(IHost host, TransformEx[] transforms)

View = transforms[transforms.Length - 1].Transform;
_tview = View as ITransposeDataView;
TransposeSchema = _tview?.TransposeSchema ?? new TransposerUtils.SimpleTransposeSchema(View.Schema);
_transposeSchema = _tview?.TransposeSchema ?? new TransposerUtils.SimpleTransposeSchema(View.Schema);

var srcLoader = transforms[0].Transform.Source as IDataLoader;

Expand Down Expand Up @@ -566,7 +566,8 @@ private static string GenerateTag(int index)

public Schema Schema => View.Schema;

public ITransposeSchema TransposeSchema { get; }
private readonly ITransposeSchema _transposeSchema;
ITransposeSchema ITransposeDataView.TransposeSchema => _transposeSchema;

public RowCursor GetRowCursor(Func<int, bool> predicate, Random rand = null)
{
Expand All @@ -586,7 +587,7 @@ public RowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator,
public SlotCursor GetSlotCursor(int col)
{
_host.CheckParam(0 <= col && col < Schema.Count, nameof(col));
if (TransposeSchema?.GetSlotType(col) == null)
if (_transposeSchema?.GetSlotType(col) == null)
{
throw _host.ExceptParam(nameof(col), "Bad call to GetSlotCursor on untransposable column '{0}'",
Schema[col].Name);
Expand Down
10 changes: 5 additions & 5 deletions src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ private static VersionInfo GetVersionInfo()
// to use the cursors from the schema view if convenient to do so.
public Schema Schema { get { return _schemaEntry.GetView().Schema; } }

public ITransposeSchema TransposeSchema { get { return _schema; } }
ITransposeSchema ITransposeDataView.TransposeSchema { get { return _schema; } }

/// <summary>
/// Whether the master schema sub-IDV has the actual data.
Expand Down Expand Up @@ -698,7 +698,7 @@ public SlotCursor GetSlotCursor(int col)
_host.CheckParam(0 <= col && col < _header.ColumnCount, nameof(col));
// We don't want the type error, if there is one, to be handled by the get-getter, because
// at the point we've gotten the interior cursor, but not yet constructed the slot cursor.
ColumnType cursorType = TransposeSchema.GetSlotType(col).ItemType;
ColumnType cursorType = _schema.GetSlotType(col).ItemType;
RowCursor inputCursor = view.GetRowCursor(c => true);
try
{
Expand Down Expand Up @@ -783,8 +783,8 @@ private Transposer EnsureAndGetTransposer(int col)
_host.AssertValue(view);
_host.Assert(view.Schema.Count == 1);
var trans = _colTransposers[col] = Transposer.Create(_host, view, false, new int[] { 0 });
_host.Assert(trans.TransposeSchema.ColumnCount == 1);
_host.Assert(trans.TransposeSchema.GetSlotType(0).ValueCount == Schema[col].Type.ValueCount);
_host.Assert(((ITransposeDataView)trans).TransposeSchema.ColumnCount == 1);
_host.Assert(((ITransposeDataView)trans).TransposeSchema.GetSlotType(0).ValueCount == Schema[col].Type.ValueCount);
}
}
}
Expand Down Expand Up @@ -845,7 +845,7 @@ private void Init(int col)
Ch.Assert(0 <= col && col < Schema.Count);
Ch.Assert(_colToActivesIndex[col] >= 0);
var type = Schema[col].Type;
Ch.Assert(_parent.TransposeSchema.GetSlotType(col).ValueCount == _parent._header.RowCount);
Ch.Assert(((ITransposeDataView)_parent).TransposeSchema.GetSlotType(col).ValueCount == _parent._header.RowCount);
Action<int> func = InitOne<int>;
if (type.IsVector)
func = InitVec<int>;
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ internal RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapp
Contracts.CheckValueOrNull(mapperFactory);
_mapper = mapper;
_mapperFactory = mapperFactory;
_bindings = new ColumnBindings(Schema.Create(input.Schema), mapper.GetOutputColumns());
_bindings = new ColumnBindings(input.Schema, mapper.GetOutputColumns());
}

[BestFriend]
Expand All @@ -110,7 +110,7 @@ private RowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView inpu
// _mapper

ctx.LoadModel<IRowMapper, SignatureLoadRowMapper>(host, out _mapper, "Mapper", input.Schema);
_bindings = new ColumnBindings(Schema.Create(input.Schema), _mapper.GetOutputColumns());
_bindings = new ColumnBindings(input.Schema, _mapper.GetOutputColumns());
}

public static RowToRowMapperTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
Expand Down
5 changes: 3 additions & 2 deletions src/Microsoft.ML.Data/DataView/Transposer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ namespace Microsoft.ML.Runtime.Data
/// were not transposable before. Note that transposition is a somewhat slow and resource intensive
/// operation.
/// </summary>
public sealed class Transposer : ITransposeDataView, IDisposable
[BestFriend]
internal sealed class Transposer : ITransposeDataView, IDisposable
{
private readonly IHost _host;
// The input view.
Expand Down Expand Up @@ -1402,7 +1403,7 @@ public override ValueGetter<TValue> GetGetter<TValue>(int col)
}
}

public static class TransposerUtils
internal static class TransposerUtils
{
/// <summary>
/// This is a convenience method that extracts a single slot value's vector,
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/EntryPoints/TransformModelImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ public CompositeRowToRowMapper(IExceptionContext ectx, IDataView chain, Schema r
_ectx.CheckValue(rootSchema, nameof(rootSchema));

_chain = chain;
_rootSchema = Schema.Create(rootSchema);
_rootSchema = rootSchema;
}

public static bool IsCompositeRowToRowMapper(IDataView chain)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ public static MultiOutputRegressionPerInstanceEvaluator Create(IHostEnvironment
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());

return new MultiOutputRegressionPerInstanceEvaluator(env, ctx, Schema.Create(schema));
return new MultiOutputRegressionPerInstanceEvaluator(env, ctx, schema);
}

public override void Save(ModelSaveContext ctx)
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ private sealed class Bindings : BindingsBase
private readonly int _truncationLevel;
private readonly MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>> _slotNamesGetter;

public Bindings(IExceptionContext ectx, ISchema input, bool user, string labelCol, string scoreCol, string groupCol,
public Bindings(IExceptionContext ectx, Schema input, bool user, string labelCol, string scoreCol, string groupCol,
int truncationLevel)
: base(ectx, input, labelCol, scoreCol, groupCol, user, Ndcg, Dcg, MaxDcg)
{
Expand Down Expand Up @@ -736,7 +736,7 @@ public override void Save(ModelSaveContext ctx)
ctx.Writer.WriteDoubleArray(_labelGains);
}

protected override BindingsBase GetBindings()
private protected override BindingsBase GetBindings()
{
return _bindings;
}
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/Scorers/GenericScorer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ internal GenericScorer(IHostEnvironment env, ScorerArgumentsBase args, IDataView
var rowMapper = mapper as ISchemaBoundRowMapper;
Host.CheckParam(rowMapper != null, nameof(mapper), "mapper should implement ISchemaBoundRowMapper");
_bindings = Bindings.Create(data.Schema, rowMapper, args.Suffix);
OutputSchema = Schema.Create(_bindings);
OutputSchema = _bindings.AsSchema;
}

/// <summary>
Expand All @@ -171,7 +171,7 @@ private GenericScorer(IHostEnvironment env, GenericScorer transform, IDataView d
: base(env, data, RegistrationName, transform.Bindable)
{
_bindings = transform._bindings.ApplyToSchema(env, data.Schema);
OutputSchema = Schema.Create(_bindings);
OutputSchema = _bindings.AsSchema;
}

/// <summary>
Expand All @@ -182,7 +182,7 @@ private GenericScorer(IHost host, ModelLoadContext ctx, IDataView input)
{
Contracts.AssertValue(ctx);
_bindings = Bindings.Create(ctx, host, Bindable, input.Schema);
OutputSchema = Schema.Create(_bindings);
OutputSchema = _bindings.AsSchema;
}

/// <summary>
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -307,15 +307,15 @@ private protected PredictedLabelScorerBase(ScorerArgumentsBase args, IHostEnviro
var predColType = getPredColType(scoreType, rowMapper);

Bindings = BindingsImpl.Create(data.Schema, rowMapper, args.Suffix, scoreColKind, scoreColIndex, predColType);
OutputSchema = Schema.Create(Bindings);
OutputSchema = Bindings.AsSchema;
}

protected PredictedLabelScorerBase(IHostEnvironment env, PredictedLabelScorerBase transform,
IDataView newSource, string registrationName)
: base(env, newSource, registrationName, transform.Bindable)
{
Bindings = transform.Bindings.ApplyToSchema(newSource.Schema, Bindable, env);
OutputSchema = Schema.Create(Bindings);
OutputSchema = Bindings.AsSchema;
}

[BestFriend]
Expand All @@ -329,7 +329,7 @@ private protected PredictedLabelScorerBase(IHost host, ModelLoadContext ctx, IDa
Host.AssertValue(getPredColType);

Bindings = BindingsImpl.Create(ctx, input.Schema, host, Bindable, outputTypeMatches, getPredColType);
OutputSchema = Schema.Create(Bindings);
OutputSchema = Bindings.AsSchema;
}

private protected override void SaveCore(ModelSaveContext ctx)
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Scorers/RowToRowScorerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ private static bool[] GetActive(BindingsBase bindings, Func<int, bool> predicate
Contracts.Assert(active.Length == bindings.ColumnCount);

var activeInput = bindings.GetActiveInput(predicate);
Contracts.Assert(activeInput.Length == bindings.Input.ColumnCount);
Contracts.Assert(activeInput.Length == bindings.Input.Count);

// Get a predicate that determines which Mapper outputs are active.
predicateMapper = bindings.GetActiveMapperColumns(active);
Expand Down
Loading