From bd3f4e1d0cd9d83e9213abcb0ac32aac71c44c4f Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Tue, 19 Feb 2019 20:32:27 -0800 Subject: [PATCH 1/3] Creation of DataViewSchemaAnnotationsExtensions and movement of public facing metadata methods. --- .../Dynamic/NgramExtraction.cs | 4 +- src/Microsoft.ML.Core/Data/MetadataUtils.cs | 55 +---------- .../Data/SchemaAnnotationsExtensions.cs | 92 +++++++++++++++++++ .../EntryPoints/PredictorModelImpl.cs | 2 +- .../Evaluators/EvaluatorUtils.cs | 6 +- .../MultiClassClassifierEvaluator.cs | 11 ++- .../Transforms/InvertHashUtils.cs | 2 +- .../Text/NgramTransform.cs | 4 +- 8 files changed, 111 insertions(+), 65 deletions(-) create mode 100644 src/Microsoft.ML.Data/Data/SchemaAnnotationsExtensions.cs diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs index 71d6fb5895..2b2cbee151 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs @@ -52,7 +52,7 @@ public static void NgramTransform() }; // Preview of the CharsUnigrams column obtained after processing the input. VBuffer> slotNames = default; - transformedData_onechars.Schema["CharsUnigrams"].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref slotNames); + transformedData_onechars.Schema["CharsUnigrams"].GetSlotNames(ref slotNames); var charsOneGramColumn = transformedData_onechars.GetColumn>(ml, "CharsUnigrams"); printHelper("CharsUnigrams", charsOneGramColumn, slotNames); @@ -62,7 +62,7 @@ public static void NgramTransform() // 'B' - 0 'e' - 6 's' - 3 't' - 6 '' - 9 'g' - 2 'a' - 2 'm' - 2 'I' - 0 ''' - 0 'v' - 0 ... // Preview of the CharsTwoGrams column obtained after processing the input. 
var charsTwoGramColumn = transformedData_twochars.GetColumn>(ml, "CharsTwograms"); - transformedData_twochars.Schema["CharsTwograms"].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref slotNames); + transformedData_twochars.Schema["CharsTwograms"].GetSlotNames(ref slotNames); printHelper("CharsTwograms", charsTwoGramColumn, slotNames); // CharsTwograms column obtained post-transformation. diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs index d56d05fec7..d2d83ad86d 100644 --- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs +++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs @@ -288,17 +288,6 @@ internal static IEnumerable GetColumnSet(this DataViewSchema schema, string } } - /// - /// Returns true if the specified column: - /// * is a vector of length N - /// * has a SlotNames metadata - /// * metadata type is VBuffer<ReadOnlyMemory<char>> of length N - /// - public static bool HasSlotNames(this DataViewSchema.Column column) - => column.Type is VectorType vectorType - && vectorType.Size > 0 - && column.HasSlotNames(vectorType.Size); - /// /// Returns true if the specified column: /// * has a SlotNames metadata @@ -318,12 +307,6 @@ internal static bool HasSlotNames(this DataViewSchema.Column column, int vectorS && vectorType.ItemType is TextDataViewType; } - public static void GetSlotNames(this DataViewSchema.Column column, ref VBuffer> slotNames) - => column.Metadata.GetValue(Kinds.SlotNames, ref slotNames); - - public static void GetKeyValues(this DataViewSchema.Column column, ref VBuffer keyValues) - => column.Metadata.GetValue(Kinds.KeyValues, ref keyValues); - [BestFriend] internal static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer> slotNames) { @@ -337,22 +320,6 @@ internal static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.Colu schema.Schema[list[0].Index].Metadata.GetValue(Kinds.SlotNames, ref slotNames); } - [BestFriend] - 
internal static bool HasKeyValues(this DataViewSchema.Column column, DataViewType type) - { - // False if type is not KeyType because GetKeyCount() returns 0. - ulong keyCount = type.GetKeyCount(); - if (keyCount == 0) - return false; - - var metaColumn = column.Metadata.Schema.GetColumnOrNull(Kinds.KeyValues); - return - metaColumn != null - && metaColumn.Value.Type is VectorType vectorType - && keyCount == (ulong)vectorType.Size - && vectorType.ItemType is TextDataViewType; - } - [BestFriend] internal static bool HasKeyValues(this SchemaShape.Column col) { @@ -361,31 +328,17 @@ internal static bool HasKeyValues(this SchemaShape.Column col) && metaCol.ItemType is TextDataViewType; } - /// - /// Returns true iff has IsNormalized metadata set to true. - /// - public static bool IsNormalized(this DataViewSchema.Column column) - { - var metaColumn = column.Metadata.Schema.GetColumnOrNull((Kinds.IsNormalized)); - if (metaColumn == null || !(metaColumn.Value.Type is BooleanDataViewType)) - return false; - - bool value = default; - column.Metadata.GetValue(Kinds.IsNormalized, ref value); - return value; - } - /// /// Returns whether a column has the metadata indicated by /// the schema shape. /// - /// The schema shape column to query + /// The schema shape column to query /// True if and only if the column has the metadata /// of a scalar type, which we assume, if set, should be true. 
- public static bool IsNormalized(this SchemaShape.Column col) + public static bool IsNormalized(this SchemaShape.Column column) { - Contracts.CheckParam(col.IsValid, nameof(col), "struct not initialized properly"); - return col.Metadata.TryFindColumn(Kinds.IsNormalized, out var metaCol) + Contracts.CheckParam(column.IsValid, nameof(column), "struct not initialized properly"); + return column.Metadata.TryFindColumn(Kinds.IsNormalized, out var metaCol) && metaCol.Kind == SchemaShape.Column.VectorKind.Scalar && !metaCol.IsKey && metaCol.ItemType == BooleanDataViewType.Instance; } diff --git a/src/Microsoft.ML.Data/Data/SchemaAnnotationsExtensions.cs b/src/Microsoft.ML.Data/Data/SchemaAnnotationsExtensions.cs new file mode 100644 index 0000000000..d1900d1683 --- /dev/null +++ b/src/Microsoft.ML.Data/Data/SchemaAnnotationsExtensions.cs @@ -0,0 +1,92 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.Data.DataView; + +namespace Microsoft.ML.Data +{ + /// + /// Extension methods to facilitate easy consumption of popular contents of . + /// + public static class SchemaAnnotationsExtensions + { + /// + /// Returns if the input column is of , and that has + /// SlotNames metadata of a whose + /// is of , and further whose matches + /// this input vector size. + /// + /// The column whose will be queried. + /// + public static bool HasSlotNames(this DataViewSchema.Column column) + => column.Type is VectorType vectorType + && vectorType.Size > 0 + && column.HasSlotNames(vectorType.Size); + + /// + /// Stores the slots names of the input column into the provided buffer, if there are slot names. + /// Otherwise it will throw an exception. + /// + /// + /// The column whose will be queried. + /// The into which the slot names will be stored. 
+ public static void GetSlotNames(this DataViewSchema.Column column, ref VBuffer> slotNames) + => column.Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref slotNames); + + /// + /// Returns true if the input column is of key type, or a vector of key type, and has + /// KeyValues metadata of a vector type whose item type + /// is keyValueItemType, and further whose size matches + /// the key count of the column's key type. + /// + /// The column whose metadata will be queried. + /// The type of the individual key-values to query. A common, + /// though not universal, type to provide is TextDataViewType, so if left unspecified + /// this will be assumed to have the value TextDataViewType.Instance. + /// + public static bool HasKeyValues(this DataViewSchema.Column column, PrimitiveDataViewType keyValueItemType = null) + { + // False if type is neither a key type nor a vector of key types. + if (!(column.Type.GetItemType() is KeyType keyType)) + return false; + + if (keyValueItemType == null) + keyValueItemType = TextDataViewType.Instance; + + var metaColumn = column.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues); + return + metaColumn != null + && metaColumn.Value.Type is VectorType vectorType + && keyType.Count == (ulong)vectorType.Size + && keyValueItemType.Equals(vectorType.ItemType); + } + + /// + /// Stores the key values of the input column into the provided buffer, if the column is of key type and its + /// key values are stored in a vector whose item type matches + /// the buffer's value type. If there is no matching key-value metadata this will throw an exception. + /// + /// The type of the key values. + /// The column whose metadata will be queried. + /// The buffer into which the key values will be stored. + public static void GetKeyValues(this DataViewSchema.Column column, ref VBuffer keyValues) + => column.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref keyValues); + + /// + /// Returns true if and only if column has IsNormalized metadata + /// set to true.
+ /// + public static bool IsNormalized(this DataViewSchema.Column column) + { + var metaColumn = column.Metadata.Schema.GetColumnOrNull((MetadataUtils.Kinds.IsNormalized)); + if (metaColumn == null || !(metaColumn.Value.Type is BooleanDataViewType)) + return false; + + bool value = default; + column.Metadata.GetValue(MetadataUtils.Kinds.IsNormalized, ref value); + return value; + } + } +} diff --git a/src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs b/src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs index 7ccc55f0f3..e6a44711ae 100644 --- a/src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs +++ b/src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs @@ -127,7 +127,7 @@ internal override string[] GetLabelInfo(IHostEnvironment env, out DataViewType l if (trainRms.Label != null) { labelType = trainRms.Label.Value.Type; - if (trainRms.Label.Value.HasKeyValues(labelType)) + if (trainRms.Label.Value.HasKeyValues()) { VBuffer> keyValues = default; trainRms.Label.Value.GetKeyValues(ref keyValues); diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs index e3d4ed0bde..cca761305d 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs @@ -837,7 +837,7 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, string { if (dvNumber == 0) { - if (dv.Schema[i].HasKeyValues(type.GetItemType())) + if (dv.Schema[i].HasKeyValues()) firstDvVectorKeyColumns.Add(name); // Store the slot names of the 1st idv and use them as baseline. if (dv.Schema[i].HasSlotNames(vectorType.Size)) @@ -866,9 +866,9 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, string // The label column can be a key. Reconcile the key values, and wrap with a KeyToValue transform. 
labelColKeyValuesType = dv.Schema[i].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; } - else if (dvNumber == 0 && dv.Schema[i].HasKeyValues(type)) + else if (dvNumber == 0 && dv.Schema[i].HasKeyValues()) firstDvKeyWithNamesColumns.Add(name); - else if (type.GetKeyCount() > 0 && name != labelColName && !dv.Schema[i].HasKeyValues(type)) + else if (type.GetKeyCount() > 0 && name != labelColName && !dv.Schema[i].HasKeyValues()) { // For any other key column (such as GroupId) we do not reconcile the key values, we only convert to U4. if (!firstDvKeyNoNamesColumns.ContainsKey(name)) diff --git a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs index bc8b528c1f..7e85bfc3ea 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs @@ -998,16 +998,17 @@ private protected override IEnumerable GetPerInstanceColumnsToSave(RoleM // Multi-class evaluator adds four per-instance columns: "Assigned", "Top scores", "Top classes" and "Log-loss". private protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema) { - // If the label column is a key without key values, convert it to I8, just for saving the per-instance + // If the label column is a key without text key values, convert it to R8 (double), just for saving the per-instance // text file, since if there are different key counts the columns cannot be appended.
string labelName = schema.Label.Value.Name; - if (!perInst.Schema.TryGetColumnIndex(labelName, out int labelCol)) + if (!perInst.Schema.TryGetColumnIndex(labelName, out int labelColIndex)) throw Host.ExceptSchemaMismatch(nameof(schema), "label", labelName); - var labelType = perInst.Schema[labelCol].Type; - if (labelType is KeyType keyType && (!perInst.Schema[labelCol].HasKeyValues(keyType) || labelType.RawType != typeof(uint))) + var labelCol = perInst.Schema[labelColIndex]; + var labelType = labelCol.Type; + if (labelType is KeyType && (!labelCol.HasKeyValues() || labelType.RawType != typeof(uint))) { perInst = LambdaColumnMapper.Create(Host, "ConvertToDouble", perInst, labelName, - labelName, perInst.Schema[labelCol].Type, NumberDataViewType.Double, + labelName, labelCol.Type, NumberDataViewType.Double, (in uint src, ref double dst) => dst = src == 0 ? double.NaN : src - 1); } diff --git a/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs b/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs index f1e9fcbd0a..e7fe5d4008 100644 --- a/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs +++ b/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs @@ -47,7 +47,7 @@ public static ValueMapper GetSimpleMapper(DataViewSchema sc bool identity; // Second choice: if key, utilize the KeyValues metadata for that key, if it has one and is text. - if (schema[col].HasKeyValues(keyType)) + if (schema[col].HasKeyValues()) { // REVIEW: Non-textual KeyValues are certainly possible. Should we handle them? // Get the key names. 
diff --git a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs index 3cadd72dc9..2d97693f88 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs @@ -509,7 +509,7 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() private void AddMetadata(int iinfo, MetadataBuilder builder) { - if (InputSchema[_srcCols[iinfo]].HasKeyValues(_srcTypes[iinfo].GetItemType())) + if (InputSchema[_srcCols[iinfo]].HasKeyValues()) { ValueGetter>> getter = (ref VBuffer> dst) => { @@ -525,7 +525,7 @@ private void GetSlotNames(int iinfo, int size, ref VBuffer> { var itemType = _srcTypes[iinfo].GetItemType(); Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); - Host.Assert(InputSchema[_srcCols[iinfo]].HasKeyValues(itemType)); + Host.Assert(InputSchema[_srcCols[iinfo]].HasKeyValues()); var unigramNames = new VBuffer>(); From d4c6f4768cbda9c7ab719841ea62e7d37de43bdb Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Tue, 19 Feb 2019 22:17:24 -0800 Subject: [PATCH 2/3] Internalize MetadataUtils and MetadataBuilder extensions, and revert members to public. --- src/Microsoft.ML.Core/Data/MetadataUtils.cs | 60 +++++++------------ .../Transforms/MetadataDispatcher.cs | 5 +- 2 files changed, 24 insertions(+), 41 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs index d2d83ad86d..43afde2d6a 100644 --- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs +++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs @@ -14,7 +14,8 @@ namespace Microsoft.ML.Data /// /// Utilities for implementing and using the metadata API of . 
/// - public static class MetadataUtils + [BestFriend] + internal static class MetadataUtils { /// /// This class lists the canonical metadata kinds @@ -114,20 +115,17 @@ public static class ScoreValueKind /// /// Returns a standard exception for responding to an invalid call to GetMetadata. /// - [BestFriend] - internal static Exception ExceptGetMetadata() => Contracts.Except("Invalid call to GetMetadata"); + public static Exception ExceptGetMetadata() => Contracts.Except("Invalid call to GetMetadata"); /// /// Returns a standard exception for responding to an invalid call to GetMetadata. /// - [BestFriend] - internal static Exception ExceptGetMetadata(this IExceptionContext ctx) => ctx.Except("Invalid call to GetMetadata"); + public static Exception ExceptGetMetadata(this IExceptionContext ctx) => ctx.Except("Invalid call to GetMetadata"); /// /// Helper to marshal a call to GetMetadata{TValue} to a specific type. /// - [BestFriend] - internal static void Marshal(this MetadataGetter getter, int col, ref TNeed dst) + public static void Marshal(this MetadataGetter getter, int col, ref TNeed dst) { Contracts.CheckValue(getter, nameof(getter)); @@ -141,8 +139,7 @@ internal static void Marshal(this MetadataGetter getter, in /// Returns a vector type with item type text and the given size. The size must be positive. /// This is a standard type for metadata consisting of multiple text values, eg SlotNames. /// - [BestFriend] - internal static VectorType GetNamesType(int size) + public static VectorType GetNamesType(int size) { Contracts.CheckParam(size > 0, nameof(size), "must be known size"); return new VectorType(TextDataViewType.Instance, size); @@ -154,8 +151,7 @@ internal static VectorType GetNamesType(int size) /// This is a standard type for metadata consisting of multiple int values that represent /// categorical slot ranges with in a column. 
/// - [BestFriend] - internal static VectorType GetCategoricalType(int rangeCount) + public static VectorType GetCategoricalType(int rangeCount) { Contracts.CheckParam(rangeCount > 0, nameof(rangeCount), "must be known size"); return new VectorType(NumberDataViewType.Int32, rangeCount, 2); @@ -166,8 +162,7 @@ internal static VectorType GetCategoricalType(int rangeCount) /// /// The type of the ScoreColumnSetId metadata. /// - [BestFriend] - internal static KeyType ScoreColumnSetIdType + public static KeyType ScoreColumnSetIdType { get { @@ -180,8 +175,7 @@ internal static KeyType ScoreColumnSetIdType /// /// Returns a key-value pair useful when implementing GetMetadataTypes(col). /// - [BestFriend] - internal static KeyValuePair GetSlotNamesPair(int size) + public static KeyValuePair GetSlotNamesPair(int size) { return GetNamesType(size).GetPair(Kinds.SlotNames); } @@ -190,8 +184,7 @@ internal static KeyValuePair GetSlotNamesPair(int size) /// Returns a key-value pair useful when implementing GetMetadataTypes(col). This assumes /// that the values of the key type are Text. /// - [BestFriend] - internal static KeyValuePair GetKeyNamesPair(int size) + public static KeyValuePair GetKeyNamesPair(int size) { return GetNamesType(size).GetPair(Kinds.KeyValues); } @@ -200,8 +193,7 @@ internal static KeyValuePair GetKeyNamesPair(int size) /// Given a type and metadata kind string, returns a key-value pair. This is useful when /// implementing GetMetadataTypes(col). /// - [BestFriend] - internal static KeyValuePair GetPair(this DataViewType type, string kind) + public static KeyValuePair GetPair(this DataViewType type, string kind) { Contracts.CheckValue(type, nameof(type)); return new KeyValuePair(kind, type); @@ -212,8 +204,7 @@ internal static KeyValuePair GetPair(this DataViewType typ /// /// Prepends a params array to an enumerable. Useful when implementing GetMetadataTypes. 
/// - [BestFriend] - internal static IEnumerable Prepend(this IEnumerable tail, params T[] head) + public static IEnumerable Prepend(this IEnumerable tail, params T[] head) { return head.Concat(tail); } @@ -252,8 +243,7 @@ public static uint GetMaxMetadataKind(this DataViewSchema schema, out int colMax /// Returns the set of column ids which match the value of specified metadata kind. /// The metadata type should be a KeyType with raw type U4. /// - [BestFriend] - internal static IEnumerable GetColumnSet(this DataViewSchema schema, string metadataKind, uint value) + public static IEnumerable GetColumnSet(this DataViewSchema schema, string metadataKind, uint value) { for (int col = 0; col < schema.Count; col++) { @@ -272,8 +262,7 @@ internal static IEnumerable GetColumnSet(this DataViewSchema schema, string /// Returns the set of column ids which match the value of specified metadata kind. /// The metadata type should be of type text. /// - [BestFriend] - internal static IEnumerable GetColumnSet(this DataViewSchema schema, string metadataKind, string value) + public static IEnumerable GetColumnSet(this DataViewSchema schema, string metadataKind, string value) { for (int col = 0; col < schema.Count; col++) { @@ -293,8 +282,7 @@ internal static IEnumerable GetColumnSet(this DataViewSchema schema, string /// * has a SlotNames metadata /// * metadata type is VBuffer<ReadOnlyMemory<char>> of length . 
/// - [BestFriend] - internal static bool HasSlotNames(this DataViewSchema.Column column, int vectorSize) + public static bool HasSlotNames(this DataViewSchema.Column column, int vectorSize) { if (vectorSize == 0) return false; @@ -307,8 +295,7 @@ internal static bool HasSlotNames(this DataViewSchema.Column column, int vectorS && vectorType.ItemType is TextDataViewType; } - [BestFriend] - internal static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer> slotNames) + public static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.ColumnRole role, int vectorSize, ref VBuffer> slotNames) { Contracts.CheckValueOrNull(schema); Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize)); @@ -320,8 +307,7 @@ internal static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.Colu schema.Schema[list[0].Index].Metadata.GetValue(Kinds.SlotNames, ref slotNames); } - [BestFriend] - internal static bool HasKeyValues(this SchemaShape.Column col) + public static bool HasKeyValues(this SchemaShape.Column col) { return col.Metadata.TryFindColumn(Kinds.KeyValues, out var metaCol) && metaCol.Kind == SchemaShape.Column.VectorKind.Vector @@ -369,8 +355,7 @@ public static bool HasSlotNames(this SchemaShape.Column col) /// The column /// The value to return, if successful /// True if the metadata of the right type exists, false otherwise - [BestFriend] - internal static bool TryGetMetadata(this DataViewSchema schema, PrimitiveDataViewType type, string kind, int col, ref T value) + public static bool TryGetMetadata(this DataViewSchema schema, PrimitiveDataViewType type, string kind, int col, ref T value) { Contracts.CheckValue(schema, nameof(schema)); Contracts.CheckValue(type, nameof(type)); @@ -390,8 +375,7 @@ internal static bool TryGetMetadata(this DataViewSchema schema, PrimitiveData /// The way to interpret that is: feature with indices 0, 1, and 2 are one categorical /// Features with indices 3 and 4 are 
another categorical. Features 5 and 6 don't appear there, so they are not categoricals. /// - [BestFriend] - internal static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int colIndex, out int[] categoricalFeatures) + public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int colIndex, out int[] categoricalFeatures) { Contracts.CheckValue(schema, nameof(schema)); Contracts.Check(colIndex >= 0, nameof(colIndex)); @@ -438,8 +422,7 @@ internal static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int /// Produces sequence of columns that are generated by trainer estimators. /// /// whether we should also append 'IsNormalized' (typically for probability column) - [BestFriend] - internal static IEnumerable GetTrainerOutputMetadata(bool isNormalized = false) + public static IEnumerable GetTrainerOutputMetadata(bool isNormalized = false) { var cols = new List(); cols.Add(new SchemaShape.Column(Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true)); @@ -455,8 +438,7 @@ internal static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int /// If input LabelColumn is not available it produces slotnames metadata by default. /// /// Label column. - [BestFriend] - internal static IEnumerable MetadataForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null) + public static IEnumerable MetadataForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null) { var cols = new List(); if (labelColumn != null && labelColumn.Value.IsKey && HasKeyValues(labelColumn.Value)) diff --git a/src/Microsoft.ML.Data/Transforms/MetadataDispatcher.cs b/src/Microsoft.ML.Data/Transforms/MetadataDispatcher.cs index f7f184ad49..f6767c08a9 100644 --- a/src/Microsoft.ML.Data/Transforms/MetadataDispatcher.cs +++ b/src/Microsoft.ML.Data/Transforms/MetadataDispatcher.cs @@ -13,7 +13,7 @@ namespace Microsoft.ML.Data /// /// Base class for handling the schema metadata API. 
/// - public abstract class MetadataDispatcherBase + internal abstract class MetadataDispatcherBase { private bool _sealed; @@ -304,7 +304,8 @@ public void GetMetadata(IExceptionContext ectx, string kind, int index, /// a builder for a particular column. Wrap the return in a using statement. Disposing the builder /// records the metadata for the column. Call Seal() once all metadata is constructed. /// - public sealed class MetadataDispatcher : MetadataDispatcherBase + [BestFriend] + internal sealed class MetadataDispatcher : MetadataDispatcherBase { public MetadataDispatcher(int colCount) : base(colCount) From 2636042b522511adf1bdf25673970dfe2eaf645d Mon Sep 17 00:00:00 2001 From: Tom Finley Date: Tue, 19 Feb 2019 23:32:33 -0800 Subject: [PATCH 3/3] Internalize MetadataBuilderExtensions --- src/Microsoft.ML.Core/Data/MetadataBuilderExtensions.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Core/Data/MetadataBuilderExtensions.cs b/src/Microsoft.ML.Core/Data/MetadataBuilderExtensions.cs index 478e9f9877..1392df0341 100644 --- a/src/Microsoft.ML.Core/Data/MetadataBuilderExtensions.cs +++ b/src/Microsoft.ML.Core/Data/MetadataBuilderExtensions.cs @@ -7,7 +7,8 @@ namespace Microsoft.ML.Data { - public static class MetadataBuilderExtensions + [BestFriend] + internal static class MetadataBuilderExtensions { /// /// Add slot names metadata.