Skip to content

Commit 45b6c43

Browse files
eerhardtTomFinley
authored andcommitted
Move KeyType, VectorType and VBuffer to ML.DataView (#3022)
* Move KeyType into ML.DataView assembly. * Rename KeyType to KeyDataViewType. * Move VBuffer to ML.DataView * Move VectorType to ML.DataView * Rename VectorType to VectorDataViewType.
1 parent b241bcd commit 45b6c43

File tree

175 files changed

+1334
-1168
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

175 files changed

+1334
-1168
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/TextClassification.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ public static void Example()
4646
// - Use it for prediction in the pipeline.
4747
var tensorFlowModel = mlContext.Model.LoadTensorFlowModel(modelLocation);
4848
var schema = tensorFlowModel.GetModelSchema();
49-
var featuresType = (VectorType)schema["Features"].Type;
49+
var featuresType = (VectorDataViewType)schema["Features"].Type;
5050
Console.WriteLine("Name: {0}, Type: {1}, Shape: (-1, {2})", "Features", featuresType.ItemType.RawType, featuresType.Dimensions[0]);
51-
var predictionType = (VectorType)schema["Prediction/Softmax"].Type;
51+
var predictionType = (VectorDataViewType)schema["Prediction/Softmax"].Type;
5252
Console.WriteLine("Name: {0}, Type: {1}, Shape: (-1, {2})", "Prediction/Softmax", predictionType.ItemType.RawType, predictionType.Dimensions[0]);
5353

5454
// The model expects the input feature vector to be a fixed length vector.

src/Microsoft.ML.Core/Data/AnnotationBuilderExtensions.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ internal static class AnnotationBuilderExtensions
1616
/// <param name="size">The size of the slot names vector.</param>
1717
/// <param name="getter">The getter delegate for the slot names.</param>
1818
public static void AddSlotNames(this DataViewSchema.Annotations.Builder builder, int size, ValueGetter<VBuffer<ReadOnlyMemory<char>>> getter)
19-
=> builder.Add(AnnotationUtils.Kinds.SlotNames, new VectorType(TextDataViewType.Instance, size), getter);
19+
=> builder.Add(AnnotationUtils.Kinds.SlotNames, new VectorDataViewType(TextDataViewType.Instance, size), getter);
2020

2121
/// <summary>
2222
/// Add key values annotation.
@@ -27,6 +27,6 @@ public static void AddSlotNames(this DataViewSchema.Annotations.Builder builder,
2727
/// <param name="valueType">The value type of key values. Its raw type must match <typeparamref name="TValue"/>.</param>
2828
/// <param name="getter">The getter delegate for the key values.</param>
2929
public static void AddKeyValues<TValue>(this DataViewSchema.Annotations.Builder builder, int size, PrimitiveDataViewType valueType, ValueGetter<VBuffer<TValue>> getter)
30-
=> builder.Add(AnnotationUtils.Kinds.KeyValues, new VectorType(valueType, size), getter);
30+
=> builder.Add(AnnotationUtils.Kinds.KeyValues, new VectorDataViewType(valueType, size), getter);
3131
}
3232
}

src/Microsoft.ML.Core/Data/AnnotationUtils.cs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public static class Kinds
3636
public const string KeyValues = "KeyValues";
3737

3838
/// <summary>
39-
/// Annotation kind for sets of score columns. The value is typically a KeyType with raw type U4.
39+
/// Annotation kind for sets of score columns. The value is typically a <see cref="KeyDataViewType"/> with raw type U4.
4040
/// </summary>
4141
public const string ScoreColumnSetId = "ScoreColumnSetId";
4242

@@ -139,10 +139,10 @@ public static void Marshal<THave, TNeed>(this AnnotationGetter<THave> getter, in
139139
/// Returns a vector type with item type text and the given size. The size must be positive.
140140
/// This is a standard type for annotation consisting of multiple text values, eg SlotNames.
141141
/// </summary>
142-
public static VectorType GetNamesType(int size)
142+
public static VectorDataViewType GetNamesType(int size)
143143
{
144144
Contracts.CheckParam(size > 0, nameof(size), "must be known size");
145-
return new VectorType(TextDataViewType.Instance, size);
145+
return new VectorDataViewType(TextDataViewType.Instance, size);
146146
}
147147

148148
/// <summary>
@@ -151,23 +151,23 @@ public static VectorType GetNamesType(int size)
151151
/// This is a standard type for annotation consisting of multiple int values that represent
152152
/// categorical slot ranges with in a column.
153153
/// </summary>
154-
public static VectorType GetCategoricalType(int rangeCount)
154+
public static VectorDataViewType GetCategoricalType(int rangeCount)
155155
{
156156
Contracts.CheckParam(rangeCount > 0, nameof(rangeCount), "must be known size");
157-
return new VectorType(NumberDataViewType.Int32, rangeCount, 2);
157+
return new VectorDataViewType(NumberDataViewType.Int32, rangeCount, 2);
158158
}
159159

160-
private static volatile KeyType _scoreColumnSetIdType;
160+
private static volatile KeyDataViewType _scoreColumnSetIdType;
161161

162162
/// <summary>
163163
/// The type of the ScoreColumnSetId annotation.
164164
/// </summary>
165-
public static KeyType ScoreColumnSetIdType
165+
public static KeyDataViewType ScoreColumnSetIdType
166166
{
167167
get
168168
{
169169
return _scoreColumnSetIdType ??
170-
Interlocked.CompareExchange(ref _scoreColumnSetIdType, new KeyType(typeof(uint), int.MaxValue), null) ??
170+
Interlocked.CompareExchange(ref _scoreColumnSetIdType, new KeyDataViewType(typeof(uint), int.MaxValue), null) ??
171171
_scoreColumnSetIdType;
172172
}
173173
}
@@ -211,7 +211,7 @@ public static IEnumerable<T> Prepend<T>(this IEnumerable<T> tail, params T[] hea
211211

212212
/// <summary>
213213
/// Returns the max value for the specified annotation kind.
214-
/// The annotation type should be a KeyType with raw type U4.
214+
/// The annotation type should be a <see cref="KeyDataViewType"/> with raw type U4.
215215
/// colMax will be set to the first column that has the max value for the specified annotation.
216216
/// If no column has the specified annotation, colMax is set to -1 and the method returns zero.
217217
/// The filter function is called for each column, passing in the schema and the column index, and returns
@@ -224,7 +224,7 @@ public static uint GetMaxAnnotationKind(this DataViewSchema schema, out int colM
224224
for (int col = 0; col < schema.Count; col++)
225225
{
226226
var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
227-
if (!(columnType is KeyType) || columnType.RawType != typeof(uint))
227+
if (!(columnType is KeyDataViewType) || columnType.RawType != typeof(uint))
228228
continue;
229229
if (filterFunc != null && !filterFunc(schema, col))
230230
continue;
@@ -241,14 +241,14 @@ public static uint GetMaxAnnotationKind(this DataViewSchema schema, out int colM
241241

242242
/// <summary>
243243
/// Returns the set of column ids which match the value of specified annotation kind.
244-
/// The annotation type should be a KeyType with raw type U4.
244+
/// The annotation type should be a <see cref="KeyDataViewType"/> with raw type U4.
245245
/// </summary>
246246
public static IEnumerable<int> GetColumnSet(this DataViewSchema schema, string annotationKind, uint value)
247247
{
248248
for (int col = 0; col < schema.Count; col++)
249249
{
250250
var columnType = schema[col].Annotations.Schema.GetColumnOrNull(annotationKind)?.Type;
251-
if (columnType is KeyType && columnType.RawType == typeof(uint))
251+
if (columnType is KeyDataViewType && columnType.RawType == typeof(uint))
252252
{
253253
uint val = 0;
254254
schema[col].Annotations.GetValue(annotationKind, ref val);
@@ -290,7 +290,7 @@ public static bool HasSlotNames(this DataViewSchema.Column column, int vectorSiz
290290
var metaColumn = column.Annotations.Schema.GetColumnOrNull(Kinds.SlotNames);
291291
return
292292
metaColumn != null
293-
&& metaColumn.Value.Type is VectorType vectorType
293+
&& metaColumn.Value.Type is VectorDataViewType vectorType
294294
&& vectorType.Size == vectorSize
295295
&& vectorType.ItemType is TextDataViewType;
296296
}
@@ -382,7 +382,7 @@ public static bool TryGetCategoricalFeatureIndices(DataViewSchema schema, int co
382382

383383
bool isValid = false;
384384
categoricalFeatures = null;
385-
if (!(schema[colIndex].Type is VectorType vecType && vecType.Size > 0))
385+
if (!(schema[colIndex].Type is VectorDataViewType vecType && vecType.Size > 0))
386386
return isValid;
387387

388388
var type = schema[colIndex].Annotations.Schema.GetColumnOrNull(Kinds.CategoricalSlotRanges)?.Type;

src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ internal static class ColumnTypeExtensions
1515
{
1616
/// <summary>
1717
/// Whether this type is a standard scalar type completely determined by its <see cref="DataViewType.RawType"/>
18-
/// (not a <see cref="KeyType"/> or <see cref="StructuredDataViewType"/>, etc).
18+
/// (not a <see cref="KeyDataViewType"/> or <see cref="StructuredDataViewType"/>, etc).
1919
/// </summary>
2020
public static bool IsStandardScalar(this DataViewType columnType) =>
2121
(columnType is NumberDataViewType) || (columnType is TextDataViewType) || (columnType is BooleanDataViewType) ||
@@ -25,7 +25,7 @@ public static bool IsStandardScalar(this DataViewType columnType) =>
2525
/// <summary>
2626
/// Zero return means it's not a key type.
2727
/// </summary>
28-
public static ulong GetKeyCount(this DataViewType columnType) => (columnType as KeyType)?.Count ?? 0;
28+
public static ulong GetKeyCount(this DataViewType columnType) => (columnType as KeyDataViewType)?.Count ?? 0;
2929

3030
/// <summary>
3131
/// Sometimes it is necessary to cast the Count to an int. This performs overflow check.
@@ -34,26 +34,26 @@ public static bool IsStandardScalar(this DataViewType columnType) =>
3434
public static int GetKeyCountAsInt32(this DataViewType columnType, IExceptionContext ectx = null)
3535
{
3636
ulong count = columnType.GetKeyCount();
37-
ectx.Check(count <= int.MaxValue, nameof(KeyType) + "." + nameof(KeyType.Count) + " exceeds int.MaxValue.");
37+
ectx.Check(count <= int.MaxValue, nameof(KeyDataViewType) + "." + nameof(KeyDataViewType.Count) + " exceeds int.MaxValue.");
3838
return (int)count;
3939
}
4040

4141
/// <summary>
4242
/// For non-vector types, this returns the column type itself (i.e., return <paramref name="columnType"/>).
4343
/// For vector types, this returns the type of the items stored as values in vector.
4444
/// </summary>
45-
public static DataViewType GetItemType(this DataViewType columnType) => (columnType as VectorType)?.ItemType ?? columnType;
45+
public static DataViewType GetItemType(this DataViewType columnType) => (columnType as VectorDataViewType)?.ItemType ?? columnType;
4646

4747
/// <summary>
4848
/// Zero return means either it's not a vector or the size is unknown.
4949
/// </summary>
50-
public static int GetVectorSize(this DataViewType columnType) => (columnType as VectorType)?.Size ?? 0;
50+
public static int GetVectorSize(this DataViewType columnType) => (columnType as VectorDataViewType)?.Size ?? 0;
5151

5252
/// <summary>
5353
/// For non-vectors, this returns one. For unknown size vectors, it returns zero.
5454
/// For known sized vectors, it returns size.
5555
/// </summary>
56-
public static int GetValueCount(this DataViewType columnType) => (columnType as VectorType)?.Size ?? 1;
56+
public static int GetValueCount(this DataViewType columnType) => (columnType as VectorDataViewType)?.Size ?? 1;
5757

5858
/// <summary>
5959
/// Whether this is a vector type with known size. Returns false for non-vector types.
@@ -85,7 +85,7 @@ public static bool SameSizeAndItemType(this DataViewType columnType, DataViewTyp
8585
return true;
8686

8787
// For vector types, we don't care about the factoring of the dimensions.
88-
if (!(columnType is VectorType vectorType) || !(other is VectorType otherVectorType))
88+
if (!(columnType is VectorDataViewType vectorType) || !(other is VectorDataViewType otherVectorType))
8989
return false;
9090
if (!vectorType.ItemType.Equals(otherVectorType.ItemType))
9191
return false;

src/Microsoft.ML.Core/Data/IEstimator.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ internal Column(string name, VectorKind vecKind, DataViewType itemType, bool isK
6363
{
6464
Contracts.CheckNonEmpty(name, nameof(name));
6565
Contracts.CheckValueOrNull(annotations);
66-
Contracts.CheckParam(!(itemType is KeyType), nameof(itemType), "Item type cannot be a key");
67-
Contracts.CheckParam(!(itemType is VectorType), nameof(itemType), "Item type cannot be a vector");
68-
Contracts.CheckParam(!isKey || KeyType.IsValidDataType(itemType.RawType), nameof(itemType), "The item type must be valid for a key");
66+
Contracts.CheckParam(!(itemType is KeyDataViewType), nameof(itemType), "Item type cannot be a key");
67+
Contracts.CheckParam(!(itemType is VectorDataViewType), nameof(itemType), "Item type cannot be a vector");
68+
Contracts.CheckParam(!isKey || KeyDataViewType.IsValidDataType(itemType.RawType), nameof(itemType), "The item type must be valid for a key");
6969

7070
Name = name;
7171
Kind = vecKind;
@@ -147,7 +147,7 @@ internal static void GetColumnTypeShape(DataViewType type,
147147
out DataViewType itemType,
148148
out bool isKey)
149149
{
150-
if (type is VectorType vectorType)
150+
if (type is VectorDataViewType vectorType)
151151
{
152152
if (vectorType.IsKnownSize)
153153
{
@@ -166,7 +166,7 @@ internal static void GetColumnTypeShape(DataViewType type,
166166
itemType = type;
167167
}
168168

169-
isKey = itemType is KeyType;
169+
isKey = itemType is KeyDataViewType;
170170
if (isKey)
171171
itemType = ColumnTypeExtensions.PrimitiveTypeFromType(itemType.RawType);
172172
}

src/Microsoft.ML.Core/Data/KeyTypeExtensions.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,17 @@
77
namespace Microsoft.ML.Data
88
{
99
/// <summary>
10-
/// Extension methods related to the KeyType class.
10+
/// Extension methods related to the <see cref="KeyDataViewType"/> class.
1111
/// </summary>
1212
[BestFriend]
1313
internal static class KeyTypeExtensions
1414
{
1515
/// <summary>
1616
/// Sometimes it is necessary to cast the Count to an int. This performs overflow check.
1717
/// </summary>
18-
public static int GetCountAsInt32(this KeyType key, IExceptionContext ectx = null)
18+
public static int GetCountAsInt32(this KeyDataViewType key, IExceptionContext ectx = null)
1919
{
20-
ectx.Check(key.Count <= int.MaxValue, nameof(KeyType) + "." + nameof(KeyType.Count) + " exceeds int.MaxValue.");
20+
ectx.Check(key.Count <= int.MaxValue, nameof(KeyDataViewType) + "." + nameof(KeyDataViewType.Count) + " exceeds int.MaxValue.");
2121
return (int)key.Count;
2222
}
2323
}
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Diagnostics;
7+
8+
namespace Microsoft.ML.Internal.Utilities
9+
{
10+
internal static partial class ArrayUtils
11+
{
12+
// Maximum size of one-dimensional array.
13+
// See: https://msdn.microsoft.com/en-us/library/hh285054(v=vs.110).aspx
14+
public const int ArrayMaxSize = 0X7FEFFFFF;
15+
16+
public static int Size<T>(T[] x)
17+
{
18+
return x == null ? 0 : x.Length;
19+
}
20+
21+
/// <summary>
22+
/// Akin to <c>FindIndexSorted</c>, except stores the found index in the output
23+
/// <c>index</c> parameter, and returns whether that index is a valid index
24+
/// pointing to a value equal to the input parameter <c>value</c>.
25+
/// </summary>
26+
public static bool TryFindIndexSorted(int[] input, int min, int lim, int value, out int index)
27+
{
28+
index = FindIndexSorted(input, min, lim, value);
29+
return index < lim && input[index] == value;
30+
}
31+
32+
/// <summary>
33+
/// Assumes input is sorted and finds value using BinarySearch.
34+
/// If value is not found, returns the logical index of 'value' in the sorted list i.e index of the first element greater than value.
35+
/// In case of duplicates it returns the index of the first one.
36+
/// It guarantees that items before the returned index are &lt; value, while those at and after the returned index are &gt;= value.
37+
/// </summary>
38+
public static int FindIndexSorted(int[] input, int min, int lim, int value)
39+
{
40+
return FindIndexSorted(input.AsSpan(), min, lim, value);
41+
}
42+
43+
/// <summary>
44+
/// Assumes input is sorted and finds value using BinarySearch.
45+
/// If value is not found, returns the logical index of 'value' in the sorted list i.e index of the first element greater than value.
46+
/// In case of duplicates it returns the index of the first one.
47+
/// It guarantees that items before the returned index are &lt; value, while those at and after the returned index are &gt;= value.
48+
/// </summary>
49+
public static int FindIndexSorted(ReadOnlySpan<int> input, int min, int lim, int value)
50+
{
51+
Debug.Assert(0 <= min & min <= lim & lim <= input.Length);
52+
53+
int minCur = min;
54+
int limCur = lim;
55+
while (minCur < limCur)
56+
{
57+
int mid = (int)(((uint)minCur + (uint)limCur) / 2);
58+
Debug.Assert(minCur <= mid & mid < limCur);
59+
60+
if (input[mid] >= value)
61+
limCur = mid;
62+
else
63+
minCur = mid + 1;
64+
65+
Debug.Assert(min <= minCur & minCur <= limCur & limCur <= lim);
66+
Debug.Assert(minCur == min || input[minCur - 1] < value);
67+
Debug.Assert(limCur == lim || input[limCur] >= value);
68+
}
69+
Debug.Assert(min <= minCur & minCur == limCur & limCur <= lim);
70+
Debug.Assert(minCur == min || input[minCur - 1] < value);
71+
Debug.Assert(limCur == lim || input[limCur] >= value);
72+
73+
return minCur;
74+
}
75+
76+
public static int EnsureSize<T>(ref T[] array, int min, int max, bool keepOld, out bool resized)
77+
{
78+
if (min > max)
79+
throw new ArgumentOutOfRangeException(nameof(max), "min must not exceed max");
80+
81+
// This code adapted from the private method EnsureCapacity code of List<T>.
82+
int size = ArrayUtils.Size(array);
83+
if (size >= min)
84+
{
85+
resized = false;
86+
return size;
87+
}
88+
89+
int newSize = size == 0 ? 4 : size * 2;
90+
// This constant taken from the internal code of system\array.cs of mscorlib.
91+
if ((uint)newSize > max)
92+
newSize = max;
93+
if (newSize < min)
94+
newSize = min;
95+
if (keepOld && size > 0)
96+
Array.Resize(ref array, newSize);
97+
else
98+
array = new T[newSize];
99+
100+
resized = true;
101+
return newSize;
102+
}
103+
}
104+
}

0 commit comments

Comments
 (0)