Skip to content

Metadata utils internalization, migration of few useful methods #2651

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public static void NgramTransform()
};
// Preview of the CharsUnigrams column obtained after processing the input.
VBuffer<ReadOnlyMemory<char>> slotNames = default;
transformedData_onechars.Schema["CharsUnigrams"].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref slotNames);
transformedData_onechars.Schema["CharsUnigrams"].GetSlotNames(ref slotNames);
var charsOneGramColumn = transformedData_onechars.GetColumn<VBuffer<float>>(ml, "CharsUnigrams");
printHelper("CharsUnigrams", charsOneGramColumn, slotNames);

Expand All @@ -62,7 +62,7 @@ public static void NgramTransform()
// 'B' - 0 'e' - 6 's' - 3 't' - 6 '<?>' - 9 'g' - 2 'a' - 2 'm' - 2 'I' - 0 ''' - 0 'v' - 0 ...
// Preview of the CharsTwoGrams column obtained after processing the input.
var charsTwoGramColumn = transformedData_twochars.GetColumn<VBuffer<float>>(ml, "CharsTwograms");
transformedData_twochars.Schema["CharsTwograms"].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref slotNames);
transformedData_twochars.Schema["CharsTwograms"].GetSlotNames(ref slotNames);
printHelper("CharsTwograms", charsTwoGramColumn, slotNames);

// CharsTwograms column obtained post-transformation.
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Core/Data/MetadataBuilderExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

namespace Microsoft.ML.Data
{
public static class MetadataBuilderExtensions
[BestFriend]
internal static class MetadataBuilderExtensions
{
/// <summary>
/// Add slot names metadata.
Expand Down
115 changes: 25 additions & 90 deletions src/Microsoft.ML.Core/Data/MetadataUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ namespace Microsoft.ML.Data
/// <summary>
/// Utilities for implementing and using the metadata API of <see cref="DataViewSchema"/>.
/// </summary>
public static class MetadataUtils
[BestFriend]
internal static class MetadataUtils
{
/// <summary>
/// This class lists the canonical metadata kinds
Expand Down Expand Up @@ -114,20 +115,17 @@ public static class ScoreValueKind
/// <summary>
/// Returns a standard exception for responding to an invalid call to GetMetadata.
/// </summary>
[BestFriend]
internal static Exception ExceptGetMetadata() => Contracts.Except("Invalid call to GetMetadata");
public static Exception ExceptGetMetadata() => Contracts.Except("Invalid call to GetMetadata");

/// <summary>
/// Returns a standard exception for responding to an invalid call to GetMetadata.
/// </summary>
[BestFriend]
internal static Exception ExceptGetMetadata(this IExceptionContext ctx) => ctx.Except("Invalid call to GetMetadata");
public static Exception ExceptGetMetadata(this IExceptionContext ctx) => ctx.Except("Invalid call to GetMetadata");

/// <summary>
/// Helper to marshal a call to GetMetadata{TValue} to a specific type.
/// </summary>
[BestFriend]
internal static void Marshal<THave, TNeed>(this MetadataGetter<THave> getter, int col, ref TNeed dst)
public static void Marshal<THave, TNeed>(this MetadataGetter<THave> getter, int col, ref TNeed dst)
{
Contracts.CheckValue(getter, nameof(getter));

Expand All @@ -141,8 +139,7 @@ internal static void Marshal<THave, TNeed>(this MetadataGetter<THave> getter, in
/// Returns a vector type with item type text and the given size. The size must be positive.
/// This is a standard type for metadata consisting of multiple text values, e.g., SlotNames.
/// </summary>
[BestFriend]
internal static VectorType GetNamesType(int size)
public static VectorType GetNamesType(int size)
{
Contracts.CheckParam(size > 0, nameof(size), "must be known size");
return new VectorType(TextDataViewType.Instance, size);
Expand All @@ -154,8 +151,7 @@ internal static VectorType GetNamesType(int size)
/// This is a standard type for metadata consisting of multiple int values that represent
/// categorical slot ranges within a column.
/// </summary>
[BestFriend]
internal static VectorType GetCategoricalType(int rangeCount)
public static VectorType GetCategoricalType(int rangeCount)
{
Contracts.CheckParam(rangeCount > 0, nameof(rangeCount), "must be known size");
return new VectorType(NumberDataViewType.Int32, rangeCount, 2);
Expand All @@ -166,8 +162,7 @@ internal static VectorType GetCategoricalType(int rangeCount)
/// <summary>
/// The type of the ScoreColumnSetId metadata.
/// </summary>
[BestFriend]
internal static KeyType ScoreColumnSetIdType
public static KeyType ScoreColumnSetIdType
{
get
{
Expand All @@ -180,8 +175,7 @@ internal static KeyType ScoreColumnSetIdType
/// <summary>
/// Returns a key-value pair useful when implementing GetMetadataTypes(col).
/// </summary>
[BestFriend]
internal static KeyValuePair<string, DataViewType> GetSlotNamesPair(int size)
public static KeyValuePair<string, DataViewType> GetSlotNamesPair(int size)
{
return GetNamesType(size).GetPair(Kinds.SlotNames);
}
Expand All @@ -190,8 +184,7 @@ internal static KeyValuePair<string, DataViewType> GetSlotNamesPair(int size)
/// Returns a key-value pair useful when implementing GetMetadataTypes(col). This assumes
/// that the values of the key type are Text.
/// </summary>
[BestFriend]
internal static KeyValuePair<string, DataViewType> GetKeyNamesPair(int size)
public static KeyValuePair<string, DataViewType> GetKeyNamesPair(int size)
{
return GetNamesType(size).GetPair(Kinds.KeyValues);
}
Expand All @@ -200,8 +193,7 @@ internal static KeyValuePair<string, DataViewType> GetKeyNamesPair(int size)
/// Given a type and metadata kind string, returns a key-value pair. This is useful when
/// implementing GetMetadataTypes(col).
/// </summary>
[BestFriend]
internal static KeyValuePair<string, DataViewType> GetPair(this DataViewType type, string kind)
public static KeyValuePair<string, DataViewType> GetPair(this DataViewType type, string kind)
{
Contracts.CheckValue(type, nameof(type));
return new KeyValuePair<string, DataViewType>(kind, type);
Expand All @@ -212,8 +204,7 @@ internal static KeyValuePair<string, DataViewType> GetPair(this DataViewType typ
/// <summary>
/// Prepends a params array to an enumerable. Useful when implementing GetMetadataTypes.
/// </summary>
[BestFriend]
internal static IEnumerable<T> Prepend<T>(this IEnumerable<T> tail, params T[] head)
public static IEnumerable<T> Prepend<T>(this IEnumerable<T> tail, params T[] head)
{
return head.Concat(tail);
}
Expand Down Expand Up @@ -252,8 +243,7 @@ public static uint GetMaxMetadataKind(this DataViewSchema schema, out int colMax
/// Returns the set of column ids which match the value of specified metadata kind.
/// The metadata type should be a KeyType with raw type U4.
/// </summary>
[BestFriend]
internal static IEnumerable<int> GetColumnSet(this DataViewSchema schema, string metadataKind, uint value)
public static IEnumerable<int> GetColumnSet(this DataViewSchema schema, string metadataKind, uint value)
{
for (int col = 0; col < schema.Count; col++)
{
Expand All @@ -272,8 +262,7 @@ internal static IEnumerable<int> GetColumnSet(this DataViewSchema schema, string
/// Returns the set of column ids which match the value of specified metadata kind.
/// The metadata type should be of type text.
/// </summary>
[BestFriend]
internal static IEnumerable<int> GetColumnSet(this DataViewSchema schema, string metadataKind, string value)
public static IEnumerable<int> GetColumnSet(this DataViewSchema schema, string metadataKind, string value)
{
for (int col = 0; col < schema.Count; col++)
{
Expand All @@ -288,24 +277,12 @@ internal static IEnumerable<int> GetColumnSet(this DataViewSchema schema, string
}
}

/// <summary>
/// Returns <c>true</c> if the specified column:
/// * is a vector of length N
/// * has a SlotNames metadata
/// * metadata type is VBuffer&lt;ReadOnlyMemory&lt;char&gt;&gt; of length N
/// </summary>
public static bool HasSlotNames(this DataViewSchema.Column column)
=> column.Type is VectorType vectorType
&& vectorType.Size > 0
&& column.HasSlotNames(vectorType.Size);

/// <summary>
/// Returns <c>true</c> if the specified column:
/// * has a SlotNames metadata
/// * metadata type is VBuffer&lt;ReadOnlyMemory&lt;char&gt;&gt; of length <paramref name="vectorSize"/>.
/// </summary>
[BestFriend]
internal static bool HasSlotNames(this DataViewSchema.Column column, int vectorSize)
public static bool HasSlotNames(this DataViewSchema.Column column, int vectorSize)
{
if (vectorSize == 0)
return false;
Expand All @@ -318,14 +295,7 @@ internal static bool HasSlotNames(this DataViewSchema.Column column, int vectorS
&& vectorType.ItemType is TextDataViewType;
}

public static void GetSlotNames(this DataViewSchema.Column column, ref VBuffer<ReadOnlyMemory<char>> slotNames)
=> column.Metadata.GetValue(Kinds.SlotNames, ref slotNames);

public static void GetKeyValues<TValue>(this DataViewSchema.Column column, ref VBuffer<TValue> keyValues)
=> column.Metadata.GetValue(Kinds.KeyValues, ref keyValues);

[BestFriend]
internal static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer<ReadOnlyMemory<char>> slotNames)
public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer<ReadOnlyMemory<char>> slotNames)
{
Contracts.CheckValueOrNull(schema);
Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize));
Expand All @@ -337,55 +307,24 @@ internal static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.Colu
schema.Schema[list[0].Index].Metadata.GetValue(Kinds.SlotNames, ref slotNames);
}

[BestFriend]
internal static bool HasKeyValues(this DataViewSchema.Column column, DataViewType type)
{
// False if type is not KeyType because GetKeyCount() returns 0.
ulong keyCount = type.GetKeyCount();
if (keyCount == 0)
return false;

var metaColumn = column.Metadata.Schema.GetColumnOrNull(Kinds.KeyValues);
return
metaColumn != null
&& metaColumn.Value.Type is VectorType vectorType
&& keyCount == (ulong)vectorType.Size
&& vectorType.ItemType is TextDataViewType;
}

[BestFriend]
internal static bool HasKeyValues(this SchemaShape.Column col)
public static bool HasKeyValues(this SchemaShape.Column col)
Copy link
Member

@eerhardt eerhardt Feb 20, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for my information - what's the difference between this method and the public static bool HasKeyValues in SchemaAnnotationsExtensions?

I can see they contain different code, but they are named the same...?

{
return col.Metadata.TryFindColumn(Kinds.KeyValues, out var metaCol)
&& metaCol.Kind == SchemaShape.Column.VectorKind.Vector
&& metaCol.ItemType is TextDataViewType;
}

/// <summary>
/// Returns true iff <paramref name="column"/> has IsNormalized metadata set to true.
/// </summary>
public static bool IsNormalized(this DataViewSchema.Column column)
{
var metaColumn = column.Metadata.Schema.GetColumnOrNull((Kinds.IsNormalized));
if (metaColumn == null || !(metaColumn.Value.Type is BooleanDataViewType))
return false;

bool value = default;
column.Metadata.GetValue(Kinds.IsNormalized, ref value);
return value;
}

/// <summary>
/// Returns whether a column has the <see cref="Kinds.IsNormalized"/> metadata indicated by
/// the schema shape.
/// </summary>
/// <param name="col">The schema shape column to query</param>
/// <param name="column">The schema shape column to query</param>
/// <returns>True if and only if the column has the <see cref="Kinds.IsNormalized"/> metadata
/// of a scalar <see cref="BooleanDataViewType"/> type, which we assume, if set, should be <c>true</c>.</returns>
public static bool IsNormalized(this SchemaShape.Column col)
public static bool IsNormalized(this SchemaShape.Column column)
{
Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly");
return col.Metadata.TryFindColumn(Kinds.IsNormalized, out var metaCol)
Contracts.CheckParam(column.IsValid, nameof(column), "struct not initialized properly");
return column.Metadata.TryFindColumn(Kinds.IsNormalized, out var metaCol)
&& metaCol.Kind == SchemaShape.Column.VectorKind.Scalar && !metaCol.IsKey
&& metaCol.ItemType == BooleanDataViewType.Instance;
}
Expand Down Expand Up @@ -416,8 +355,7 @@ public static bool HasSlotNames(this SchemaShape.Column col)
/// <param name="col">The column</param>
/// <param name="value">The value to return, if successful</param>
/// <returns>True if the metadata of the right type exists, false otherwise</returns>
[BestFriend]
internal static bool TryGetMetadata<T>(this DataViewSchema schema, PrimitiveDataViewType type, string kind, int col, ref T value)
public static bool TryGetMetadata<T>(this DataViewSchema schema, PrimitiveDataViewType type, string kind, int col, ref T value)
{
Contracts.CheckValue(schema, nameof(schema));
Contracts.CheckValue(type, nameof(type));
Expand All @@ -437,8 +375,7 @@ internal static bool TryGetMetadata<T>(this DataViewSchema schema, PrimitiveData
/// The way to interpret that is: features with indices 0, 1, and 2 are one categorical.
/// Features with indices 3 and 4 are another categorical. Features 5 and 6 don't appear there, so they are not categoricals.
/// </summary>
[BestFriend]
internal static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int colIndex, out int[] categoricalFeatures)
public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int colIndex, out int[] categoricalFeatures)
{
Contracts.CheckValue(schema, nameof(schema));
Contracts.Check(colIndex >= 0, nameof(colIndex));
Expand Down Expand Up @@ -485,8 +422,7 @@ internal static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int
/// Produces sequence of columns that are generated by trainer estimators.
/// </summary>
/// <param name="isNormalized">whether we should also append 'IsNormalized' (typically for probability column)</param>
[BestFriend]
internal static IEnumerable<SchemaShape.Column> GetTrainerOutputMetadata(bool isNormalized = false)
public static IEnumerable<SchemaShape.Column> GetTrainerOutputMetadata(bool isNormalized = false)
{
var cols = new List<SchemaShape.Column>();
cols.Add(new SchemaShape.Column(Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true));
Expand All @@ -502,8 +438,7 @@ internal static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int
/// If input LabelColumn is not available it produces slotnames metadata by default.
/// </summary>
/// <param name="labelColumn">Label column.</param>
[BestFriend]
internal static IEnumerable<SchemaShape.Column> MetadataForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null)
public static IEnumerable<SchemaShape.Column> MetadataForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null)
{
var cols = new List<SchemaShape.Column>();
if (labelColumn != null && labelColumn.Value.IsKey && HasKeyValues(labelColumn.Value))
Expand Down
Loading