diff --git a/src/Microsoft.ML.Api/CodeGenerationUtils.cs b/src/Microsoft.ML.Api/CodeGenerationUtils.cs
index 7af0fb85ed..1c39f1b9dd 100644
--- a/src/Microsoft.ML.Api/CodeGenerationUtils.cs
+++ b/src/Microsoft.ML.Api/CodeGenerationUtils.cs
@@ -75,11 +75,11 @@ public static void AppendFieldDeclaration(CSharpCodeProvider codeProvider, Strin
target.Append(indent);
target.AppendFormat("public {0} {1}", generatedCsTypeName, fieldName);
- if (appendInitializer && colType.IsKnownSizeVector && !useVBuffer)
+ if (appendInitializer && colType is VectorType vecColType && vecColType.Size > 0 && !useVBuffer)
{
Contracts.Assert(generatedCsTypeName.EndsWith("[]"));
var csItemType = generatedCsTypeName.Substring(0, generatedCsTypeName.Length - 2);
- target.AppendFormat(" = new {0}[{1}]", csItemType, colType.VectorSize);
+ target.AppendFormat(" = new {0}[{1}]", csItemType, vecColType.Size);
}
target.AppendLine(";");
}
@@ -109,17 +109,13 @@ private static string GetBackingTypeName(ColumnType colType, bool useVBuffer, Li
{
Contracts.AssertValue(colType);
Contracts.AssertValue(attributes);
- if (colType.IsVector)
+ if (colType is VectorType vecColType)
{
- if (colType.IsKnownSizeVector)
+ if (vecColType.Size > 0)
{
// By default, arrays are assumed variable length, unless a [VectorType(dim1, dim2, ...)]
// attribute is applied to the fields.
- var vectorType = colType.AsVector;
- var dimensions = new int[vectorType.DimCount];
- for (int i = 0; i < dimensions.Length; i++)
- dimensions[i] = vectorType.GetDim(i);
- attributes.Add(string.Format("[VectorType({0})]", string.Join(", ", dimensions)));
+ attributes.Add(string.Format("[VectorType({0})]", string.Join(", ", vecColType.Dimensions)));
}
var itemType = GetBackingTypeName(colType.ItemType, false, attributes);
diff --git a/src/Microsoft.ML.Core/Data/ColumnType.cs b/src/Microsoft.ML.Core/Data/ColumnType.cs
index 69a6505c51..83b5caa27d 100644
--- a/src/Microsoft.ML.Core/Data/ColumnType.cs
+++ b/src/Microsoft.ML.Core/Data/ColumnType.cs
@@ -5,6 +5,7 @@
#pragma warning disable 420 // volatile with Interlocked.CompareExchange
using System;
+using System.Collections.Immutable;
using System.Linq;
using System.Reflection;
using System.Text;
@@ -14,41 +15,36 @@
namespace Microsoft.ML.Runtime.Data
{
///
- /// ColumnType is the abstract base class for all types in the IDataView type system.
+ /// This is the abstract base class for all types in the type system.
///
public abstract class ColumnType : IEquatable
{
- private readonly Type _rawType;
- private readonly DataKind _rawKind;
-
- // We cache these for speed and code size.
- private readonly bool _isPrimitive;
- private readonly bool _isVector;
- private readonly bool _isNumber;
- private readonly bool _isKey;
-
- // This private constructor sets all the _isXxx flags. It is invoked by other ctors.
+ // This private constructor sets all the IsXxx flags. It is invoked by other ctors.
private ColumnType()
{
- _isPrimitive = this is PrimitiveType;
- _isVector = this is VectorType;
- _isNumber = this is NumberType;
- _isKey = this is KeyType;
+ IsPrimitive = this is PrimitiveType;
+ IsVector = this is VectorType;
+ IsNumber = this is NumberType;
+ IsKey = this is KeyType;
}
- protected ColumnType(Type rawType)
+ ///
+ /// Constructor for extension types, which must be either or .
+ ///
+ private protected ColumnType(Type rawType)
: this()
{
Contracts.CheckValue(rawType, nameof(rawType));
- _rawType = rawType;
- _rawType.TryGetDataKind(out _rawKind);
+ RawType = rawType;
+ RawType.TryGetDataKind(out var rawKind);
+ RawKind = rawKind;
}
///
- /// Internal sub types can pass both the rawType and rawKind values. This asserts that they
- /// are consistent.
+ /// Internal sub types can pass both the and values.
+ /// This asserts that they are consistent.
///
- internal ColumnType(Type rawType, DataKind rawKind)
+ private protected ColumnType(Type rawType, DataKind rawKind)
: this()
{
Contracts.AssertValue(rawType);
@@ -57,42 +53,50 @@ internal ColumnType(Type rawType, DataKind rawKind)
rawType.TryGetDataKind(out tmp);
Contracts.Assert(tmp == rawKind);
#endif
- _rawType = rawType;
- _rawKind = rawKind;
+ RawType = rawType;
+ RawKind = rawKind;
}
///
- /// The raw System.Type for this ColumnType. Note that this is the raw representation type
- /// and NOT the complete information content of the ColumnType. Code should not assume that
- /// a RawType uniquely identifiers a ColumnType.
+ /// The raw for this . Note that this is the raw representation type
+ /// and not the complete information content of the . Code should not assume that
+ /// a uniquely identifiers a . For example, most practical instances of
+ /// and will have a of ,
+ /// but both are very different in the types of information conveyed in that number.
///
- public Type RawType { get { return _rawType; } }
+ public Type RawType { get; }
///
- /// The DataKind corresponding to RawType, if there is one (zero otherwise). It is equivalent
- /// to the result produced by DataKindExtensions.TryGetDataKind(RawType, out kind).
+ /// The corresponding to , if there is one (default otherwise).
+ /// It is equivalent to the result produced by .
+ /// For external code it would be preferable to operate over .
///
- public DataKind RawKind { get { return _rawKind; } }
+ [BestFriend]
+ internal DataKind RawKind { get; }
///
- /// Whether this is a primitive type.
+ /// Whether this is a primitive type. External code should use is .
///
- public bool IsPrimitive { get { return _isPrimitive; } }
+ [BestFriend]
+ internal bool IsPrimitive { get; }
///
- /// Equivalent to "this as PrimitiveType".
+ /// Equivalent to as .
///
- public PrimitiveType AsPrimitive { get { return _isPrimitive ? (PrimitiveType)this : null; } }
+ [BestFriend]
+ internal PrimitiveType AsPrimitive => IsPrimitive ? (PrimitiveType)this : null;
///
- /// Whether this type is a standard numeric type.
+ /// Whether this type is a standard numeric type. External code should use is .
///
- public bool IsNumber { get { return _isNumber; } }
+ [BestFriend]
+ internal bool IsNumber { get; }
///
- /// Whether this type is the standard text type.
+ /// Whether this type is the standard text type. External code should use is .
///
- public bool IsText
+ [BestFriend]
+ internal bool IsText
{
get
{
@@ -105,9 +109,10 @@ public bool IsText
}
///
- /// Whether this type is the standard boolean type.
+ /// Whether this type is the standard boolean type. External code should use is .
///
- public bool IsBool
+ [BestFriend]
+ internal bool IsBool
{
get
{
@@ -120,118 +125,93 @@ public bool IsBool
}
///
- /// Whether this type is the standard type.
- ///
- public bool IsTimeSpan
- {
- get
- {
- Contracts.Assert((this == TimeSpanType.Instance) == (this is TimeSpanType));
- return this is TimeSpanType;
- }
- }
-
- ///
- /// Whether this type is a .
- ///
- public bool IsDateTime
- {
- get
- {
- Contracts.Assert((this == DateTimeType.Instance) == (this is DateTimeType));
- return this is DateTimeType;
- }
- }
-
- ///
- /// Whether this type is a
- ///
- public bool IsDateTimeZone
- {
- get
- {
- Contracts.Assert((this == DateTimeOffsetType.Instance) == (this is DateTimeOffsetType));
- return this is DateTimeOffsetType;
- }
- }
-
- ///
- /// Whether this type is a standard scalar type completely determined by its RawType
- /// (not a KeyType or StructureType, etc).
+ /// Whether this type is a standard scalar type completely determined by its
+ /// (not a or , etc).
///
- public bool IsStandardScalar
- {
- get { return IsNumber || IsText || IsBool || IsTimeSpan || IsDateTime || IsDateTimeZone; }
- }
+ [BestFriend]
+ internal bool IsStandardScalar => IsNumber || IsText || IsBool ||
+ (this is TimeSpanType) || (this is DateTimeType) || (this is DateTimeOffsetType);
///
/// Whether this type is a key type, which implies that the order of values is not significant,
/// and arithmetic is non-sensical. A key type can define a cardinality.
+ /// External code should use is .
///
- public bool IsKey { get { return _isKey; } }
+ [BestFriend]
+ internal bool IsKey { get; }
///
- /// Equivalent to "this as KeyType".
+ /// Equivalent to as .
///
- public KeyType AsKey { get { return _isKey ? (KeyType)this : null; } }
+ [BestFriend]
+ internal KeyType AsKey => IsKey ? (KeyType)this : null;
///
- /// Zero return means either it's not a key type or the cardinality is unknown.
+ /// Zero return means either it's not a key type or the cardinality is unknown. External code should first
+ /// test whether this is of type , then if so get the property
+ /// from that.
///
- public int KeyCount { get { return KeyCountCore; } }
+ [BestFriend]
+ internal int KeyCount => KeyCountCore;
///
- /// The only sub-class that should override this is KeyType!
+ /// The only sub-class that should override this is .
///
- internal virtual int KeyCountCore { get { return 0; } }
+ private protected virtual int KeyCountCore => 0;
///
- /// Whether this is a vector type.
+ /// Whether this is a vector type. External code should just check directly against whether this type
+ /// is .
///
- public bool IsVector { get { return _isVector; } }
+ [BestFriend]
+ internal bool IsVector { get; }
///
- /// Equivalent to "this as VectorType".
+ /// Equivalent to as .
///
- public VectorType AsVector { get { return _isVector ? (VectorType)this : null; } }
+ [BestFriend]
+ internal VectorType AsVector => IsVector ? (VectorType)this : null;
///
- /// For non-vector types, this returns the column type itself (ie, return this).
+ /// For non-vector types, this returns the column type itself (i.e., return this).
///
- public ColumnType ItemType { get { return ItemTypeCore; } }
+ [BestFriend]
+ internal ColumnType ItemType => ItemTypeCore;
///
/// Whether this is a vector type with known size. Returns false for non-vector types.
- /// Equivalent to VectorSize > 0.
+ /// Equivalent to > 0.
///
- public bool IsKnownSizeVector { get { return VectorSize > 0; } }
+ [BestFriend]
+ internal bool IsKnownSizeVector => VectorSize > 0;
///
- /// Zero return means either it's not a vector or the size is unknown. Equivalent to
- /// IsVector ? ValueCount : 0 and to IsKnownSizeVector ? ValueCount : 0.
+ /// Zero return means either it's not a vector or the size is unknown.
///
- public int VectorSize { get { return VectorSizeCore; } }
+ [BestFriend]
+ internal int VectorSize => VectorSizeCore;
///
/// For non-vectors, this returns one. For unknown size vectors, it returns zero.
/// Equivalent to IsVector ? VectorSize : 1.
///
- public int ValueCount { get { return ValueCountCore; } }
+ [BestFriend]
+ internal int ValueCount => ValueCountCore;
///
/// The only sub-class that should override this is VectorType!
///
- internal virtual ColumnType ItemTypeCore { get { return this; } }
+ private protected virtual ColumnType ItemTypeCore => this;
///
- /// The only sub-class that should override this is VectorType!
+ /// The only sub-class that should override this is !
///
- internal virtual int VectorSizeCore { get { return 0; } }
+ private protected virtual int VectorSizeCore => 0;
///
/// The only sub-class that should override this is VectorType!
///
- internal virtual int ValueCountCore { get { return 1; } }
+ private protected virtual int ValueCountCore => 1;
// IEquatable interface recommends also to override base class implementations of
// Object.Equals(Object) and GetHashCode. In classes below where Equals(ColumnType other)
@@ -243,7 +223,8 @@ public bool IsStandardScalar
/// Equivalent to calling Equals(ColumnType) for non-vector types. For vector type,
/// returns true if current and other vector types have the same size and item type.
///
- public bool SameSizeAndItemType(ColumnType other)
+ [BestFriend]
+ internal bool SameSizeAndItemType(ColumnType other)
{
if (other == null)
return false;
@@ -271,7 +252,7 @@ protected StructuredType(Type rawType)
Contracts.Assert(!IsPrimitive);
}
- internal StructuredType(Type rawType, DataKind rawKind)
+ private protected StructuredType(Type rawType, DataKind rawKind)
: base(rawType, rawKind)
{
Contracts.Assert(!IsPrimitive);
@@ -289,17 +270,18 @@ protected PrimitiveType(Type rawType)
{
Contracts.Assert(IsPrimitive);
Contracts.CheckParam(!typeof(IDisposable).IsAssignableFrom(RawType), nameof(rawType),
- "A PrimitiveType cannot have a disposable RawType");
+ "A " + nameof(PrimitiveType) + " cannot have a disposable " + nameof(RawType));
}
- internal PrimitiveType(Type rawType, DataKind rawKind)
+ private protected PrimitiveType(Type rawType, DataKind rawKind)
: base(rawType, rawKind)
{
Contracts.Assert(IsPrimitive);
Contracts.Assert(!typeof(IDisposable).IsAssignableFrom(RawType));
}
- public static PrimitiveType FromKind(DataKind kind)
+ [BestFriend]
+ internal static PrimitiveType FromKind(DataKind kind)
{
if (kind == DataKind.TX)
return TextType.Instance;
@@ -344,10 +326,7 @@ public override bool Equals(ColumnType other)
return false;
}
- public override string ToString()
- {
- return "Text";
- }
+ public override string ToString() => "Text";
}
///
@@ -486,51 +465,50 @@ public static NumberType R8
}
}
- public static NumberType Float
- {
- get { return R4; }
- }
+ public static NumberType Float => R4;
- public static new NumberType FromKind(DataKind kind)
+ [BestFriend]
+ internal static new NumberType FromKind(DataKind kind)
{
switch (kind)
{
- case DataKind.I1:
- return I1;
- case DataKind.U1:
- return U1;
- case DataKind.I2:
- return I2;
- case DataKind.U2:
- return U2;
- case DataKind.I4:
- return I4;
- case DataKind.U4:
- return U4;
- case DataKind.I8:
- return I8;
- case DataKind.U8:
- return U8;
- case DataKind.R4:
- return R4;
- case DataKind.R8:
- return R8;
- case DataKind.UG:
- return UG;
+ case DataKind.I1:
+ return I1;
+ case DataKind.U1:
+ return U1;
+ case DataKind.I2:
+ return I2;
+ case DataKind.U2:
+ return U2;
+ case DataKind.I4:
+ return I4;
+ case DataKind.U4:
+ return U4;
+ case DataKind.I8:
+ return I8;
+ case DataKind.U8:
+ return U8;
+ case DataKind.R4:
+ return R4;
+ case DataKind.R8:
+ return R8;
+ case DataKind.UG:
+ return UG;
}
Contracts.Assert(false);
- throw Contracts.Except("Bad data kind in NumericType.FromKind: {0}", kind);
+ throw Contracts.Except($"Bad data kind in {nameof(NumberType)}.{nameof(FromKind)}: {kind}");
}
- public static NumberType FromType(Type type)
+ [BestFriend]
+ internal static NumberType FromType(Type type)
{
DataKind kind;
if (type.TryGetDataKind(out kind))
return FromKind(kind);
Contracts.Assert(false);
- throw Contracts.Except("Bad data kind in NumericType.FromKind: {0}", kind);
+ throw Contracts.Except($"Bad data kind in {nameof(NumberType)}.{nameof(FromType)}: {kind}", kind);
}
public override bool Equals(ColumnType other)
@@ -541,10 +519,7 @@ public override bool Equals(ColumnType other)
return false;
}
- public override string ToString()
- {
- return _name;
- }
+ public override string ToString() => _name;
}
///
@@ -608,10 +583,7 @@ public override bool Equals(ColumnType other)
return false;
}
- public override string ToString()
- {
- return "DateTime";
- }
+ public override string ToString() => "DateTime";
}
public sealed class DateTimeOffsetType : PrimitiveType
@@ -640,10 +612,7 @@ public override bool Equals(ColumnType other)
return false;
}
- public override string ToString()
- {
- return "DateTimeZone";
- }
+ public override string ToString() => "DateTimeZone";
}
///
@@ -675,10 +644,7 @@ public override bool Equals(ColumnType other)
return false;
}
- public override string ToString()
- {
- return "TimeSpan";
- }
+ public override string ToString() => "TimeSpan";
}
///
@@ -699,26 +665,45 @@ public override string ToString()
///
public sealed class KeyType : PrimitiveType
{
- private readonly bool _contiguous;
- private readonly ulong _min;
- // _count is only valid if _contiguous is true. Zero means unknown.
- private readonly int _count;
-
- public KeyType(DataKind kind, ulong min, int count, bool contiguous = true)
- : base(ToRawType(kind), kind)
+ private KeyType(Type type, DataKind kind, ulong min, int count, bool contiguous)
+ : base(type, kind)
{
+ Contracts.AssertValue(type);
+ Contracts.Assert(kind.ToType() == type);
+
Contracts.CheckParam(min >= 0, nameof(min));
- Contracts.CheckParam(count >= 0, nameof(count), "count for key type must be non-negative");
+ Contracts.CheckParam(count >= 0, nameof(count), "Must be non-negative.");
Contracts.CheckParam((ulong)count <= ulong.MaxValue - min, nameof(count));
Contracts.CheckParam((ulong)count <= kind.ToMaxInt(), nameof(count));
- Contracts.CheckParam(contiguous || count == 0, nameof(count), "count must be 0 for non-contiguous");
+ Contracts.CheckParam(contiguous || count == 0, nameof(count), "Must be 0 for non-contiguous");
- _contiguous = contiguous;
- _min = min;
- _count = count;
+ Contiguous = contiguous;
+ Min = min;
+ Count = count;
Contracts.Assert(IsKey);
}
+ public KeyType(Type type, ulong min, int count, bool contiguous = true)
+ : this(type, CheckRefRawType(type), min, count, contiguous)
+ {
+ }
+
+ [BestFriend]
+ internal KeyType(DataKind kind, ulong min, int count, bool contiguous = true)
+ : this(ToRawType(kind), kind, min, count, contiguous)
+ {
+ }
+
+ private static DataKind CheckRefRawType(Type type)
+ {
+ Contracts.CheckValue(type, nameof(type));
+ Contracts.CheckParam(IsValidDataType(type), nameof(type));
+ var result = type.TryGetDataKind(out var kind);
+ Contracts.Assert(result);
+ return kind;
+
+ }
+
private static Type ToRawType(DataKind kind)
{
Contracts.CheckParam(IsValidDataKind(kind), nameof(kind));
@@ -726,31 +711,43 @@ private static Type ToRawType(DataKind kind)
}
///
- /// Returns true iff the given DataKind is valid for a KeyType. The valid ones are
- /// U1, U2, U4, and U8, that is, the unsigned integer kinds.
+ /// Returns true iff the given DataKind is valid for a . The valid ones are
+ /// , , , and ,
+ /// that is, the unsigned integer kinds.
///
- public static bool IsValidDataKind(DataKind kind)
+ [BestFriend]
+ internal static bool IsValidDataKind(DataKind kind)
{
switch (kind)
{
- case DataKind.U1:
- case DataKind.U2:
- case DataKind.U4:
- case DataKind.U8:
- return true;
- default:
- return false;
+ case DataKind.U1:
+ case DataKind.U2:
+ case DataKind.U4:
+ case DataKind.U8:
+ return true;
+ default:
+ return false;
}
}
- internal override int KeyCountCore { get { return _count; } }
+ ///
+ /// Returns true iff the given type is valid for a . The valid ones are
+ /// , , , and , that is, the unsigned integer types.
+ ///
+ public static bool IsValidDataType(Type type)
+ {
+ Contracts.CheckValue(type, nameof(type));
+ return type == typeof(byte) || type == typeof(ushort) || type == typeof(uint) || type == typeof(ulong);
+ }
+
+ private protected override int KeyCountCore => Count;
///
/// This is the Min of the key type for display purposes and conversion to/from text. The values
/// actually stored always start at 1 (for the smallest legal value), with zero being reserved
/// for "not there"/"none". Typical Min values are 0 or 1, but can be any value >= 0.
///
- public ulong Min { get { return _min; } }
+ public ulong Min { get; }
///
/// If this key type has contiguous values and a known cardinality, Count is that cardinality.
@@ -760,9 +757,9 @@ public static bool IsValidDataKind(DataKind kind)
/// representation. Note that an id of 0 is used to represent the notion "none", which is
/// typically mapped to a vector of all zeros (of length Count).
///
- public int Count { get { return _count; } }
+ public int Count { get; }
- public bool Contiguous { get { return _contiguous; } }
+ public bool Contiguous { get; }
public override bool Equals(ColumnType other)
{
@@ -775,11 +772,11 @@ public override bool Equals(ColumnType other)
if (RawKind != tmp.RawKind)
return false;
Contracts.Assert(RawType == tmp.RawType);
- if (_contiguous != tmp._contiguous)
+ if (Contiguous != tmp.Contiguous)
return false;
- if (_min != tmp._min)
+ if (Min != tmp.Min)
return false;
- if (_count != tmp._count)
+ if (Count != tmp.Count)
return false;
return true;
}
@@ -791,17 +788,17 @@ public override bool Equals(object other)
public override int GetHashCode()
{
- return Hashing.CombinedHash(RawKind.GetHashCode(), _contiguous, _min, _count);
+ return Hashing.CombinedHash(RawKind.GetHashCode(), Contiguous, Min, Count);
}
public override string ToString()
{
- if (_count > 0)
- return string.Format("Key<{0}, {1}-{2}>", RawKind.GetString(), _min, _min + (ulong)_count - 1);
- if (_contiguous)
- return string.Format("Key<{0}, {1}-*>", RawKind.GetString(), _min);
+ if (Count > 0)
+ return string.Format("Key<{0}, {1}-{2}>", RawKind.GetString(), Min, Min + (ulong)Count - 1);
+ if (Contiguous)
+ return string.Format("Key<{0}, {1}-*>", RawKind.GetString(), Min);
// This is the non-contiguous case - simply show the Min.
- return string.Format("Key<{0}, Min:{1}>", RawKind.GetString(), _min);
+ return string.Format("Key<{0}, Min:{1}>", RawKind.GetString(), Min);
}
}
@@ -810,74 +807,75 @@ public override string ToString()
///
public sealed class VectorType : StructuredType
{
- private readonly PrimitiveType _itemType;
- private readonly int _size;
-
- // The _sizes are the cumulative products of the _dims. These may be null, meaning that
- // the information is naturally one dimensional.
- private readonly int[] _sizes;
- private readonly int[] _dims;
+ /// b
+ /// The dimensions. This will always have at least one item. All values will be non-negative.
+ /// As with , a zero value indicates that the vector type is considered to have
+ /// unknown length along that dimension.
+ ///
+ public ImmutableArray Dimensions { get; }
+ ///
+ /// Constructs a new single-dimensional vector type.
+ ///
+ /// The type of the items contained in the vector.
+ /// The size of the single dimension.
public VectorType(PrimitiveType itemType, int size = 0)
: base(GetRawType(itemType), 0)
{
Contracts.CheckParam(size >= 0, nameof(size));
- _itemType = itemType;
- _size = size;
+ ItemType = itemType;
+ Size = size;
+ Dimensions = ImmutableArray.Create(Size);
}
- public VectorType(PrimitiveType itemType, params int[] dims)
- : base(GetRawType(itemType), default(DataKind))
+ ///
+ /// Constructs a potentially multi-dimensional vector type.
+ ///
+ /// The type of the items contained in the vector.
+ /// The dimensions. Note that, like , must be non-empty, with all
+ /// non-negative values. Also, because is the product of , the result of
+ /// multiplying all these values together must not overflow .
+ public VectorType(PrimitiveType itemType, params int[] dimensions)
+ : base(GetRawType(itemType), default)
{
- Contracts.CheckParam(Utils.Size(dims) > 0, nameof(dims));
- Contracts.CheckParam(dims.All(d => d >= 0), nameof(dims));
-
- _itemType = itemType;
+ Contracts.CheckParam(Utils.Size(dimensions) > 0, nameof(dimensions));
+ Contracts.CheckParam(dimensions.All(d => d >= 0), nameof(dimensions));
- if (dims.Length == 1)
- _size = dims[0];
- else
- {
- _dims = new int[dims.Length];
- Array.Copy(dims, _dims, _dims.Length);
- _size = ComputeSizes(_dims, out _sizes);
- }
+ ItemType = itemType;
+ Dimensions = dimensions.ToImmutableArray();
+ Size = ComputeSize(Dimensions);
}
///
- /// Creates a VectorType whose dimensionality information is the given template's information.
+ /// Creates a whose dimensionality information is the given 's information.
///
- public VectorType(PrimitiveType itemType, VectorType template)
- : base(GetRawType(itemType), default(DataKind))
+ [BestFriend]
+ internal VectorType(PrimitiveType itemType, VectorType template)
+ : base(GetRawType(itemType), default)
{
Contracts.CheckValue(template, nameof(template));
- _itemType = itemType;
- _size = template._size;
- _sizes = template._sizes;
- _dims = template._dims;
+ ItemType = itemType;
+ Dimensions = template.Dimensions;
+ Size = template.Size;
}
///
- /// Creates a VectorType whose dimensionality information is the given template's information
- /// concatenated with the specified dims.
+ /// Creates a whose dimensionality information is the given 's information,
+ /// concatenated with the specified .
///
- public VectorType(PrimitiveType itemType, VectorType template, params int[] dims)
- : base(GetRawType(itemType), default(DataKind))
+ [BestFriend]
+ internal VectorType(PrimitiveType itemType, VectorType template, params int[] dims)
+ : base(GetRawType(itemType), default)
{
Contracts.CheckValue(template, nameof(template));
+ Contracts.CheckParam(Utils.Size(dims) > 0, nameof(dims));
+ Contracts.CheckParam(dims.All(d => d >= 0), nameof(dims));
- _itemType = itemType;
-
- if (template._dims == null)
- _dims = Utils.Concat(new int[] { template._size }, dims);
- else
- {
- Contracts.Assert(template._dims.Length >= 2);
- _dims = Utils.Concat(template._dims, dims);
- }
- _size = ComputeSizes(_dims, out _sizes);
+ ItemType = itemType;
+ Dimensions = template.Dimensions.AddRange(dims);
+ Size = ComputeSize(Dimensions);
}
private static Type GetRawType(PrimitiveType itemType)
@@ -886,60 +884,48 @@ private static Type GetRawType(PrimitiveType itemType)
return typeof(VBuffer<>).MakeGenericType(itemType.RawType);
}
- private static int ComputeSizes(int[] dims, out int[] sizes)
+ private static int ComputeSize(ImmutableArray dims)
{
- sizes = new int[dims.Length];
int size = 1;
- for (int i = dims.Length; --i >= 0; )
- size = sizes[i] = checked(size * dims[i]);
+ for (int i = 0; i < dims.Length; ++i)
+ size = checked(size * dims[i]);
return size;
}
- public int DimCount { get { return _dims != null ? _dims.Length : _size > 0 ? 1 : 0; } }
-
- public int GetDim(int idim)
- {
- if (_dims == null)
- {
- // That that if _size is zero, DimCount is zero, so this method is illegal
- // to call. That case is caught by Check(_size > 0).
- Contracts.Check(_size > 0);
- Contracts.Assert(DimCount == 1);
- Contracts.CheckParam(idim == 0, nameof(idim));
- return _size;
- }
-
- Contracts.CheckParam(0 <= idim && idim < _dims.Length, nameof(idim));
- return _dims[idim];
- }
+ ///
+ /// The type of the items stored as values in vectors of this type.
+ ///
+ public new PrimitiveType ItemType { get; }
- public new PrimitiveType ItemType { get { return _itemType; } }
+ ///
+ /// The size of the vector. A value of zero means it is a vector whose size is unknown.
+ /// A vector whose size is known should correspond to values that always have the same ,
+ /// whereas one whose size is known may have values whose varies from record to record.
+ /// Note that this is always the product of the elements in .
+ ///
+ public int Size { get; }
- internal override ColumnType ItemTypeCore { get { return _itemType; } }
+ private protected override ColumnType ItemTypeCore => ItemType;
- internal override int VectorSizeCore { get { return _size; } }
+ private protected override int VectorSizeCore => Size;
- internal override int ValueCountCore { get { return _size; } }
+ private protected override int ValueCountCore => Size;
public override bool Equals(ColumnType other)
{
if (other == this)
return true;
- var tmp = other.AsVector;
- if (tmp == null)
+ if (!(other is VectorType tmp))
return false;
- if (!_itemType.Equals(tmp._itemType))
+ if (!ItemType.Equals(tmp.ItemType))
return false;
- if (_size != tmp._size)
+ if (Size != tmp.Size)
return false;
- int count = Utils.Size(_dims);
- if (count != Utils.Size(tmp._dims))
+ if (Dimensions.Length != tmp.Dimensions.Length)
return false;
- if (count == 0)
- return true;
- for (int i = 0; i < count; i++)
+ for (int i = 0; i < Dimensions.Length; i++)
{
- if (_dims[i] != tmp._dims[i])
+ if (Dimensions[i] != tmp.Dimensions[i])
return false;
}
return true;
@@ -952,47 +938,26 @@ public override bool Equals(object other)
public override int GetHashCode()
{
- int hash = Hashing.CombinedHash(_itemType.GetHashCode(), _size);
- int count = Utils.Size(_dims);
- hash = Hashing.CombineHash(hash, count.GetHashCode());
- for (int i = 0; i < count; i++)
- hash = Hashing.CombineHash(hash, _dims[i].GetHashCode());
+ int hash = Hashing.CombinedHash(ItemType.GetHashCode(), Size);
+ hash = Hashing.CombineHash(hash, Dimensions.Length);
+ for (int i = 0; i < Dimensions.Length; i++)
+ hash = Hashing.CombineHash(hash, Dimensions[i].GetHashCode());
return hash;
}
- ///
- /// Returns true if current has the same item type of other, and the size
- /// of other is unknown or the current size is equal to the size of other.
- ///
- public bool IsSubtypeOf(VectorType other)
- {
- if (other == this)
- return true;
- if (other == null)
- return false;
-
- // REVIEW: Perhaps we should allow the case when _itemType is
- // a sub-type of other._itemType (in particular for key types)
- if (!_itemType.Equals(other._itemType))
- return false;
- if (other._size == 0 || _size == other._size)
- return true;
- return false;
- }
-
public override string ToString()
{
var sb = new StringBuilder();
- sb.Append("Vec<").Append(_itemType);
+ sb.Append("Vec<").Append(ItemType);
- if (_dims == null)
+ if (Dimensions.Length == 1)
{
- if (_size > 0)
- sb.Append(", ").Append(_size);
+ if (Size > 0)
+ sb.Append(", ").Append(Size);
}
else
{
- foreach (var dim in _dims)
+ foreach (var dim in Dimensions)
{
sb.Append(", ");
if (dim > 0)
diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs
index 7ed3aecd9e..ee70bc732a 100644
--- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs
+++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs
@@ -14,7 +14,7 @@
namespace Microsoft.ML.Runtime.Data
{
///
- /// Utilities for implementing and using the metadata API of ISchema.
+ /// Utilities for implementing and using the metadata API of .
///
public static class MetadataUtils
{
@@ -433,7 +433,7 @@ public static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex,
bool isValid = false;
categoricalFeatures = null;
- if (!schema.GetColumnType(colIndex).IsKnownSizeVector)
+ if (!(schema.GetColumnType(colIndex) is VectorType vecType && vecType.Size > 0))
return isValid;
var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex);
@@ -442,7 +442,7 @@ public static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex,
VBuffer catIndices = default(VBuffer);
schema.GetMetadata(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex, ref catIndices);
VBufferUtils.Densify(ref catIndices);
- int columnSlotsCount = schema.GetColumnType(colIndex).AsVector.VectorSizeCore;
+ int columnSlotsCount = vecType.Size;
if (catIndices.Length > 0 && catIndices.Length % 2 == 0 && catIndices.Length <= columnSlotsCount * 2)
{
int previousEndIndex = -1;
diff --git a/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj b/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj
index c0f8558822..dea0558bb8 100644
--- a/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj
+++ b/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj
@@ -12,6 +12,7 @@
+
diff --git a/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs
index 453d5d5d44..0ef400e485 100644
--- a/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs
+++ b/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs
@@ -3,5 +3,39 @@
// See the LICENSE file in the project root for more information.
using System.Runtime.CompilerServices;
+using Microsoft.ML;
-[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TestFramework, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TestFramework" + PublicKey.TestValue)]
+
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Legacy" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Maml" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.PipelineInference" + PublicKey.Value)]
+
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Data" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Api" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Ensemble" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.FastTree" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.HalLearners" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.KMeansClustering" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.LightGBM" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Onnx" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxTransform" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.PCA" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Recommender" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Runtime.ImageAnalytics" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Scoring" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StandardLearners" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TensorFlow" + PublicKey.Value)]
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Transforms" + PublicKey.Value)]
+
+[assembly: WantsToBeBestFriends]
+
+namespace Microsoft.ML
+{
+ [BestFriend]
+ internal static class PublicKey
+ {
+ public const string Value = ", PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb";
+ public const string TestValue = ", PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4";
+ }
+}
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs
index 05a776cc7b..2362741bc8 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs
@@ -827,13 +827,13 @@ public VBufferCodec(CodecFactory factory, VectorType type, IValueCodec innerC
public int WriteParameterization(Stream stream)
{
int total = _factory.WriteCodec(stream, _innerCodec);
- int count = _type.DimCount;
+ int count = _type.Dimensions.Length;
total += sizeof(int) * (1 + count);
using (BinaryWriter writer = _factory.OpenBinaryWriter(stream))
{
writer.Write(count);
for (int i = 0; i < count; i++)
- writer.Write(_type.GetDim(i));
+ writer.Write(_type.Dimensions[i]);
}
return total;
}
@@ -1163,7 +1163,13 @@ private bool GetVBufferCodec(Stream definitionStream, out IValueCodec codec)
type = new VectorType(itemType, dims);
}
else
+ {
+ // In prior times, in the case where the VectorType was of single rank, *and* of unknown length,
+ // then the vector type would be considered to have a dimension count of 0, for some reason.
+ // This can no longer occur, but in the case where we read an older file we have to account for
+ // the fact that nothing may have been written.
type = new VectorType(itemType);
+ }
}
// Next create the vbuffer codec.
Type codecType = typeof(VBufferCodec<>).MakeGenericType(itemType.RawType);
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs
index 8e83f01ac1..db84057b79 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs
@@ -97,17 +97,17 @@ protected ValueWriterBase(PrimitiveType type, int source, char sep)
ValueMapper, StringBuilder> c = MapText;
Conv = (ValueMapper)(Delegate)c;
}
- else if (type.IsTimeSpan)
+ else if (type is TimeSpanType)
{
ValueMapper c = MapTimeSpan;
Conv = (ValueMapper)(Delegate)c;
}
- else if (type.IsDateTime)
+ else if (type is DateTimeType)
{
ValueMapper c = MapDateTime;
Conv = (ValueMapper)(Delegate)c;
}
- else if (type.IsDateTimeZone)
+ else if (type is DateTimeOffsetType)
{
ValueMapper c = MapDateTimeZone;
Conv = (ValueMapper)(Delegate)c;
diff --git a/src/Microsoft.ML.Data/Transforms/NAFilter.cs b/src/Microsoft.ML.Data/Transforms/NAFilter.cs
index 05ffea70b5..cd49e30b78 100644
--- a/src/Microsoft.ML.Data/Transforms/NAFilter.cs
+++ b/src/Microsoft.ML.Data/Transforms/NAFilter.cs
@@ -113,7 +113,7 @@ public NAFilter(IHostEnvironment env, Arguments args, IDataView input)
var type = schema.GetColumnType(index);
if (!TestType(type))
- throw Host.ExceptUserArg(nameof(args.Column), "Column '{0}' does not have compatible numeric type", src);
+ throw Host.ExceptUserArg(nameof(args.Column), $"Column '{src}' has type {type} which does not support missing values, so we cannot filter on them", src);
_infos[i] = new ColInfo(index, type);
_srcIndexToInfoIndex.Add(index, i);
@@ -149,7 +149,7 @@ public NAFilter(IHost host, ModelLoadContext ctx, IDataView input)
var type = schema.GetColumnType(index);
if (!TestType(type))
- throw Host.Except("Column '{0}' does not have compatible numeric type", src);
+ throw Host.Except($"Column '{src}' has type {type} which does not support missing values, so we cannot filter on them", src);
_infos[i] = new ColInfo(index, type);
_srcIndexToInfoIndex.Add(index, i);
@@ -187,32 +187,12 @@ private static bool TestType(ColumnType type)
{
Contracts.AssertValue(type);
- var itemType = type.ItemType;
- if (itemType.IsNumber)
- {
- switch (itemType.RawKind)
- {
- case DataKind.I1:
- case DataKind.I2:
- case DataKind.I4:
- case DataKind.I8:
- case DataKind.R4:
- case DataKind.R8:
- return true;
- }
- return false;
- }
- if (itemType.IsText)
- return true;
- if (itemType.IsBool)
- return true;
- if (itemType.IsKey)
- return true;
- if (itemType.IsTimeSpan)
+ var itemType = (type as VectorType)?.ItemType ?? type;
+ if (itemType == NumberType.R4)
return true;
- if (itemType.IsDateTime)
+ if (itemType == NumberType.R8)
return true;
- if (itemType.IsDateTimeZone)
+ if (itemType is KeyType)
return true;
return false;
}
@@ -287,15 +267,15 @@ public static Value Create(RowCursor cursor, ColInfo info)
Contracts.AssertValue(info);
MethodInfo meth;
- if (!info.Type.IsVector)
+ if (info.Type is VectorType vecType)
{
- Func> d = CreateOne;
- meth = d.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(info.Type.RawType);
+ Func> d = CreateVec;
+ meth = d.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(vecType.ItemType.RawType);
}
else
{
- Func> d = CreateVec;
- meth = d.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(info.Type.ItemType.RawType);
+ Func> d = CreateOne;
+ meth = d.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(info.Type.RawType);
}
return (Value)meth.Invoke(null, new object[] { cursor, info });
}
@@ -304,7 +284,7 @@ private static ValueOne CreateOne(RowCursor cursor, ColInfo info)
{
Contracts.AssertValue(cursor);
Contracts.AssertValue(info);
- Contracts.Assert(!info.Type.IsVector);
+ Contracts.Assert(!(info.Type is VectorType));
Contracts.Assert(info.Type.RawType == typeof(T));
var getSrc = cursor.Input.GetGetter(info.Index);
@@ -316,8 +296,8 @@ private static ValueVec CreateVec(RowCursor cursor, ColInfo info)
{
Contracts.AssertValue(cursor);
Contracts.AssertValue(info);
- Contracts.Assert(info.Type.IsVector);
- Contracts.Assert(info.Type.ItemType.RawType == typeof(T));
+ Contracts.Assert(info.Type is VectorType);
+ Contracts.Assert(info.Type.RawType == typeof(VBuffer));
var getSrc = cursor.Input.GetGetter>(info.Index);
var hasBad = Runtime.Data.Conversion.Conversions.Instance.GetHasMissingPredicate((VectorType)info.Type);
diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs
index 1bb0db3b4c..d19cf6f11b 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs
@@ -442,15 +442,16 @@ protected override Delegate MakeGetter(IRow input, int iinfo, out Action dispose
private ValueGetter> GetGetterCore(IRow input, int iinfo, out Action disposer)
{
var type = _types[iinfo];
- Contracts.Assert(type.DimCount == 3);
+ var dims = type.Dimensions;
+ Contracts.Assert(dims.Length == 3);
var ex = _parent._columns[iinfo];
- int planes = ex.Interleave ? type.GetDim(2) : type.GetDim(0);
- int height = ex.Interleave ? type.GetDim(0) : type.GetDim(1);
- int width = ex.Interleave ? type.GetDim(1) : type.GetDim(2);
+ int planes = ex.Interleave ? dims[2] : dims[0];
+ int height = ex.Interleave ? dims[0] : dims[1];
+ int width = ex.Interleave ? dims[1] : dims[2];
- int size = type.ValueCount;
+ int size = type.Size;
Contracts.Assert(size > 0);
Contracts.Assert(size == planes * height * width);
int cpix = height * width;
diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs
index c325f87a4d..c0bcb5318d 100644
--- a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs
+++ b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs
@@ -922,8 +922,8 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl
else
{
ArrayDataViewBuilder arrDv = new ArrayDataViewBuilder(host);
- arrDv.AddColumn(DefaultColumnNames.Features, PrimitiveType.FromKind(DataKind.R4), clusters);
- arrDv.AddColumn(DefaultColumnNames.Weight, PrimitiveType.FromKind(DataKind.R4), totalWeights);
+ arrDv.AddColumn(DefaultColumnNames.Features, NumberType.R4, clusters);
+ arrDv.AddColumn(DefaultColumnNames.Weight, NumberType.R4, totalWeights);
var subDataViewCursorFactory = new FeatureFloatVectorCursor.Factory(
new RoleMappedData(arrDv.GetDataView(), null, DefaultColumnNames.Features, weight: DefaultColumnNames.Weight), CursOpt.Weight | CursOpt.Features);
long discard1;
diff --git a/src/Microsoft.ML.Onnx/OnnxUtils.cs b/src/Microsoft.ML.Onnx/OnnxUtils.cs
index 83623ab0ba..e7c06ee53a 100644
--- a/src/Microsoft.ML.Onnx/OnnxUtils.cs
+++ b/src/Microsoft.ML.Onnx/OnnxUtils.cs
@@ -368,8 +368,8 @@ public static ModelArgs GetModelArgs(ColumnType type, string colName,
else if (type.ValueCount > 1)
{
var vec = type.AsVector;
- for (int i = 0; i < vec.DimCount; i++)
- dimsLocal.Add(vec.GetDim(i));
+ for (int i = 0; i < vec.Dimensions.Length; i++)
+ dimsLocal.Add(vec.Dimensions[i]);
}
}
//batch size.
diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
index bacc9945b7..fb981a1bf2 100644
--- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
@@ -828,20 +828,23 @@ public Mapper(IHostEnvironment env, TensorFlowTransform parent, ISchema inputSch
throw _host.Except($"Column {_parent.Inputs[i]} doesn't exist");
var type = inputSchema.GetColumnType(_inputColIndices[i]);
- if (type.IsVector && type.VectorSize == 0)
- throw _host.Except($"Variable length input columns not supported");
+ if (type is VectorType vecType && vecType.Size == 0)
+ throw _host.Except("Variable length input columns not supported");
- _isInputVector[i] = type.IsVector;
+ _isInputVector[i] = type is VectorType;
+ if (!_isInputVector[i]) // Temporary pending fix of issue #1542. In its current state, the below code would fail anyway with a naked exception if this check was not here.
+ throw _host.Except("Non-vector columns not supported");
+ vecType = (VectorType)type;
var expectedType = TensorFlowUtils.Tf2MlNetType(_parent.TFInputTypes[i]);
if (type.ItemType != expectedType)
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], expectedType.ToString(), type.ToString());
var originalShape = _parent.TFInputShapes[i];
var shape = originalShape.ToIntArray();
- var colTypeDims = Enumerable.Range(0, type.AsVector.DimCount + 1).Select(d => d == 0 ? 1 : (long)type.AsVector.GetDim(d - 1)).ToArray();
+ var colTypeDims = vecType.Dimensions.Prepend(1).Select(dim => (long)dim).ToArray();
if (shape == null)
_fullySpecifiedShapes[i] = new TFShape(colTypeDims);
- else if (type.AsVector.DimCount == 1)
+ else if (vecType.Dimensions.Length == 1)
{
// If the column is one dimension we make sure that the total size of the TF shape matches.
// Compute the total size of the known dimensions of the shape.
diff --git a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs
index b0c9df8406..6e59b928bf 100644
--- a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs
+++ b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs
@@ -153,8 +153,7 @@ internal static string TestType(ColumnType type)
{
// Item type must have an NA value that exists and is not equal to its default value.
Func func = TestType;
- var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.ItemType.RawType);
- return (string)meth.Invoke(null, new object[] { type.ItemType });
+ return Utils.MarshalInvoke(func, type.ItemType.RawType, type.ItemType);
}
private static string TestType(ColumnType type)
@@ -339,7 +338,7 @@ private void GetReplacementValues(IDataView input, ColumnInfo[] columns, out obj
case ReplacementKind.Mean:
case ReplacementKind.Minimum:
case ReplacementKind.Maximum:
- if (!type.ItemType.IsNumber && !type.ItemType.IsTimeSpan && !type.ItemType.IsDateTime)
+ if (!type.ItemType.IsNumber)
throw Host.Except("Cannot perform mean imputations on non-numeric '{0}'", type.ItemType);
imputationModes[iinfo] = kind;
Utils.Add(ref columnsToImpute, iinfo);
diff --git a/test/Microsoft.ML.Benchmarks/HashBench.cs b/test/Microsoft.ML.Benchmarks/HashBench.cs
index 53e5789a13..74a056ea77 100644
--- a/test/Microsoft.ML.Benchmarks/HashBench.cs
+++ b/test/Microsoft.ML.Benchmarks/HashBench.cs
@@ -46,7 +46,7 @@ private void InitMap(T val, ColumnType type, int hashBits = 20)
var mapper = xf.GetRowToRowMapper(inRow.Schema);
mapper.Schema.TryGetColumnIndex("Bar", out int outCol);
var outRow = mapper.GetRow(inRow, c => c == outCol, out var _);
- if (type.IsVector)
+ if (type is VectorType)
_vecGetter = outRow.GetGetter>(outCol);
else
_getter = outRow.GetGetter(outCol);
@@ -123,7 +123,7 @@ public void HashScalarDouble()
[GlobalSetup(Target = nameof(HashScalarKey))]
public void SetupHashScalarKey()
{
- InitMap(6u, new KeyType(DataKind.U4, 0, 100));
+ InitMap(6u, new KeyType(typeof(uint), 0, 100));
}
[Benchmark]
@@ -174,7 +174,7 @@ public void HashVectorDouble()
[GlobalSetup(Target = nameof(HashVectorKey))]
public void SetupHashVectorKey()
{
- InitDenseVecMap(new[] { 1u, 2u, 0u, 4u, 5u }, new KeyType(DataKind.U4, 0, 100));
+ InitDenseVecMap(new[] { 1u, 2u, 0u, 4u, 5u }, new KeyType(typeof(uint), 0, 100));
}
[Benchmark]
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/ColumnTypes.cs b/test/Microsoft.ML.Core.Tests/UnitTests/ColumnTypes.cs
index f3cfa8c270..ac0bd8c187 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/ColumnTypes.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/ColumnTypes.cs
@@ -17,12 +17,15 @@ public void TestEqualAndGetHashCode()
{
var dict = new Dictionary();
// add PrimitiveTypes, KeyType & corresponding VectorTypes
- PrimitiveType tmp;
VectorType tmp1, tmp2;
- foreach (var kind in (DataKind[])Enum.GetValues(typeof(DataKind)))
+ var types = new PrimitiveType[] { NumberType.I1, NumberType.I2, NumberType.I4, NumberType.I8,
+ NumberType.U1, NumberType.U2, NumberType.U4, NumberType.U8, NumberType.UG,
+ TextType.Instance, BoolType.Instance, DateTimeType.Instance, DateTimeOffsetType.Instance, TimeSpanType.Instance };
+
+ foreach (var type in types)
{
- tmp = PrimitiveType.FromKind(kind);
- if(dict.ContainsKey(tmp) && dict[tmp] != tmp.ToString())
+ var tmp = type;
+ if (dict.ContainsKey(tmp) && dict[tmp] != tmp.ToString())
Assert.True(false, dict[tmp] + " and " + tmp.ToString() + " are duplicates.");
dict[tmp] = tmp.ToString();
for (int size = 0; size < 5; size++)
@@ -41,13 +44,14 @@ public void TestEqualAndGetHashCode()
}
// KeyType & Vector
- if (!KeyType.IsValidDataKind(kind))
+ var rawType = tmp.RawType;
+ if (!KeyType.IsValidDataType(rawType))
continue;
for (ulong min = 0; min < 5; min++)
{
for (var count = 0; count < 5; count++)
{
- tmp = new KeyType(kind, min, count);
+ tmp = new KeyType(rawType, min, count);
if (dict.ContainsKey(tmp) && dict[tmp] != tmp.ToString())
Assert.True(false, dict[tmp] + " and " + tmp.ToString() + " are duplicates.");
dict[tmp] = tmp.ToString();
@@ -66,7 +70,7 @@ public void TestEqualAndGetHashCode()
}
}
}
- tmp = new KeyType(kind, min, 0, false);
+ tmp = new KeyType(rawType, min, 0, false);
if (dict.ContainsKey(tmp) && dict[tmp] != tmp.ToString())
Assert.True(false, dict[tmp] + " and " + tmp.ToString() + " are duplicates.");
dict[tmp] = tmp.ToString();
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs b/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs
index 34eb063b5b..b166007f0a 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs
@@ -29,7 +29,9 @@ protected bool EqualTypes(ColumnType type1, ColumnType type2, bool exactTypes)
Contracts.AssertValue(type1);
Contracts.AssertValue(type2);
- return exactTypes ? type1.Equals(type2) : type1.SameSizeAndItemType(type2);
+ if (type1.Equals(type2))
+ return true;
+ return !exactTypes && type1 is VectorType vt1 && type2 is VectorType vt2 && vt1.ItemType.Equals(vt2.ItemType) && vt1.Size == vt2.Size;
}
protected Func GetIdComparer(IRow r1, IRow r2, out ValueGetter idGetter)
@@ -148,88 +150,93 @@ protected bool CompareVec(in VBuffer v1, in VBuffer v2, int size, Func<
}
protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType type, bool exactDoubles)
{
- if (!type.IsVector)
+ if (type is VectorType vecType)
{
- switch (type.RawKind)
+ int size = vecType.Size;
+ Contracts.Assert(size >= 0);
+ var result = vecType.ItemType.RawType.TryGetDataKind(out var kind);
+ Contracts.Assert(result);
+
+ switch (kind)
{
case DataKind.I1:
- return GetComparerOne(r1, r2, col, (x, y) => x == y);
+ return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
case DataKind.U1:
- return GetComparerOne(r1, r2, col, (x, y) => x == y);
+ return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
case DataKind.I2:
- return GetComparerOne(r1, r2, col, (x, y) => x == y);
+ return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
case DataKind.U2:
- return GetComparerOne(r1, r2, col, (x, y) => x == y);
+ return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
case DataKind.I4:
- return GetComparerOne(r1, r2, col, (x, y) => x == y);
+ return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
case DataKind.U4:
- return GetComparerOne(r1, r2, col, (x, y) => x == y);
+ return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
case DataKind.I8:
- return GetComparerOne(r1, r2, col, (x, y) => x == y);
+ return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
case DataKind.U8:
- return GetComparerOne(r1, r2, col, (x, y) => x == y);
+ return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
case DataKind.R4:
- return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y));
+ return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y));
case DataKind.R8:
if (exactDoubles)
- return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y));
+ return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y));
else
- return GetComparerOne(r1, r2, col, EqualWithEps);
+ return GetComparerVec(r1, r2, col, size, EqualWithEps);
case DataKind.Text:
- return GetComparerOne>(r1, r2, col, (a, b) => a.Span.SequenceEqual(b.Span));
+ return GetComparerVec>(r1, r2, col, size, (a, b) => a.Span.SequenceEqual(b.Span));
case DataKind.Bool:
- return GetComparerOne(r1, r2, col, (x, y) => x == y);
+ return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
case DataKind.TimeSpan:
- return GetComparerOne(r1, r2, col, (x, y) => x.Ticks == y.Ticks);
+ return GetComparerVec(r1, r2, col, size, (x, y) => x.Ticks == y.Ticks);
case DataKind.DT:
- return GetComparerOne(r1, r2, col, (x, y) => x.Ticks == y.Ticks);
+ return GetComparerVec(r1, r2, col, size, (x, y) => x.Ticks == y.Ticks);
case DataKind.DZ:
- return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y));
+ return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y));
case DataKind.UG:
- return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y));
+ return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y));
}
}
else
{
- int size = type.VectorSize;
- Contracts.Assert(size >= 0);
- switch (type.ItemType.RawKind)
+ var result = type.RawType.TryGetDataKind(out var kind);
+ Contracts.Assert(result);
+ switch (kind)
{
case DataKind.I1:
- return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
+ return GetComparerOne(r1, r2, col, (x, y) => x == y);
case DataKind.U1:
- return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
+ return GetComparerOne(r1, r2, col, (x, y) => x == y);
case DataKind.I2:
- return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
+ return GetComparerOne(r1, r2, col, (x, y) => x == y);
case DataKind.U2:
- return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
+ return GetComparerOne(r1, r2, col, (x, y) => x == y);
case DataKind.I4:
- return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
+ return GetComparerOne(r1, r2, col, (x, y) => x == y);
case DataKind.U4:
- return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
+ return GetComparerOne(r1, r2, col, (x, y) => x == y);
case DataKind.I8:
- return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
+ return GetComparerOne(r1, r2, col, (x, y) => x == y);
case DataKind.U8:
- return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
+ return GetComparerOne(r1, r2, col, (x, y) => x == y);
case DataKind.R4:
- return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y));
+ return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y));
case DataKind.R8:
if (exactDoubles)
- return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y));
+ return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y));
else
- return GetComparerVec(r1, r2, col, size, EqualWithEps);
+ return GetComparerOne(r1, r2, col, EqualWithEps);
case DataKind.Text:
- return GetComparerVec>(r1, r2, col, size, (a,b) => a.Span.SequenceEqual(b.Span));
+ return GetComparerOne>(r1, r2, col, (a, b) => a.Span.SequenceEqual(b.Span));
case DataKind.Bool:
- return GetComparerVec(r1, r2, col, size, (x, y) => x == y);
+ return GetComparerOne(r1, r2, col, (x, y) => x == y);
case DataKind.TimeSpan:
- return GetComparerVec(r1, r2, col, size, (x, y) => x.Ticks == y.Ticks);
+ return GetComparerOne(r1, r2, col, (x, y) => x.Ticks == y.Ticks);
case DataKind.DT:
- return GetComparerVec(r1, r2, col, size, (x, y) => x.Ticks == y.Ticks);
+ return GetComparerOne(r1, r2, col, (x, y) => x.Ticks == y.Ticks);
case DataKind.DZ:
- return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y));
+ return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y));
case DataKind.UG:
- return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y));
+ return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y));
}
}
@@ -318,7 +325,7 @@ protected bool CheckSameValues(IRowCursor curs1, IRowCursor curs2, bool exactTyp
var type2 = curs2.Schema.GetColumnType(col);
if (!EqualTypes(type1, type2, exactTypes))
{
- Fail("Different types");
+ Fail($"Different types {type1} and {type2}");
return Failed();
}
comps[col] = GetColumnComparer(curs1, curs2, col, type1, exactDoubles);
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs
index 3f1408efd9..385e6cd527 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs
@@ -21,10 +21,10 @@ public DataTypesTest(ITestOutputHelper helper)
[Fact]
public void R4ToSBtoR4()
{
- var r4ToSB = Conversions.Instance.GetStringConversion(NumberType.FromKind(DataKind.R4));
+ var r4ToSB = Conversions.Instance.GetStringConversion(NumberType.R4);
var txToR4 = Conversions.Instance.GetStandardConversion< ReadOnlyMemory, float>(
- TextType.Instance, NumberType.FromKind(DataKind.R4), out bool identity2);
+ TextType.Instance, NumberType.R4, out bool identity2);
Assert.NotNull(r4ToSB);
Assert.NotNull(txToR4);
@@ -45,10 +45,10 @@ public void R4ToSBtoR4()
[Fact]
public void R8ToSBtoR8()
{
- var r8ToSB = Conversions.Instance.GetStringConversion(NumberType.FromKind(DataKind.R8));
+ var r8ToSB = Conversions.Instance.GetStringConversion(NumberType.R8);
var txToR8 = Conversions.Instance.GetStandardConversion, double>(
- TextType.Instance, NumberType.FromKind(DataKind.R8), out bool identity2);
+ TextType.Instance, NumberType.R8, out bool identity2);
Assert.NotNull(r8ToSB);
Assert.NotNull(txToR8);
@@ -69,7 +69,7 @@ public void R8ToSBtoR8()
[Fact]
public void TXToSByte()
{
- var mapper = GetMapper, sbyte>();
+ var mapper = GetMapper, sbyte>(NumberType.I1);
Assert.NotNull(mapper);
@@ -109,7 +109,7 @@ public void TXToSByte()
[Fact]
public void TXToShort()
{
- var mapper = GetMapper, short>();
+ var mapper = GetMapper, short>(NumberType.I2);
Assert.NotNull(mapper);
@@ -149,7 +149,7 @@ public void TXToShort()
[Fact]
public void TXToInt()
{
- var mapper = GetMapper, int>();
+ var mapper = GetMapper, int>(NumberType.I4);
Assert.NotNull(mapper);
@@ -189,7 +189,7 @@ public void TXToInt()
[Fact]
public void TXToLong()
{
- var mapper = GetMapper, long>();
+ var mapper = GetMapper, long>(NumberType.I8);
Assert.NotNull(mapper);
@@ -226,14 +226,12 @@ public void TXToLong()
Assert.Equal(default, dst);
}
- public ValueMapper GetMapper()
+ private static ValueMapper GetMapper(ColumnType dstType)
{
- Assert.True(typeof(TDst).TryGetDataKind(out DataKind dstDataKind));
+ Assert.True(typeof(TDst) == dstType.RawType);
return Conversions.Instance.GetStandardConversion(
- TextType.Instance, NumberType.FromKind(dstDataKind), out bool identity);
+ TextType.Instance, dstType, out bool identity);
}
}
}
-
-
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
index af9fedfde3..5a38ecb6d0 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
@@ -504,7 +504,7 @@ public void TestCrossValidationMacroWithMultiClass()
b = schema.TryGetColumnIndex("Fold Index", out foldCol);
Assert.True(b);
var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, countCol);
- Assert.True(type != null && type.ItemType.IsText && type.VectorSize == 10);
+ Assert.True(type is VectorType vecType && vecType.ItemType is TextType && vecType.Size == 10);
var slotNames = default(VBuffer>);
schema.GetMetadata(MetadataUtils.Kinds.SlotNames, countCol, ref slotNames);
Assert.True(slotNames.Values.Select((s, i) => ReadOnlyMemoryUtils.EqualsStr(i.ToString(), s)).All(x => x));
@@ -992,7 +992,7 @@ public void TestTensorFlowEntryPoint()
var schema = data.Schema;
Assert.Equal(3, schema.ColumnCount);
Assert.Equal("Softmax", schema.GetColumnName(2));
- Assert.Equal(10, schema.GetColumnType(2).VectorSize);
+ Assert.Equal(10, (schema.GetColumnType(2) as VectorType)?.Size);
}
}
}
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
index 523af789f4..1a5b46af14 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
@@ -851,7 +851,7 @@ public void EntryPointPipelineEnsemble()
var hasScoreCol = binaryScored.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out int scoreIndex);
Assert.True(hasScoreCol, "Data scored with binary ensemble does not have a score column");
var type = binaryScored.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.ScoreColumnKind, scoreIndex);
- Assert.True(type != null && type.IsText, "Binary ensemble scored data does not have correct type of metadata.");
+ Assert.True(type is TextType, "Binary ensemble scored data does not have correct type of metadata.");
var kind = default(ReadOnlyMemory);
binaryScored.Schema.GetMetadata(MetadataUtils.Kinds.ScoreColumnKind, scoreIndex, ref kind);
Assert.True(ReadOnlyMemoryUtils.EqualsStr(MetadataUtils.Const.ScoreColumnKind.BinaryClassification, kind),
@@ -860,7 +860,7 @@ public void EntryPointPipelineEnsemble()
hasScoreCol = regressionScored.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreIndex);
Assert.True(hasScoreCol, "Data scored with regression ensemble does not have a score column");
type = regressionScored.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.ScoreColumnKind, scoreIndex);
- Assert.True(type != null && type.IsText, "Regression ensemble scored data does not have correct type of metadata.");
+ Assert.True(type is TextType, "Regression ensemble scored data does not have correct type of metadata.");
regressionScored.Schema.GetMetadata(MetadataUtils.Kinds.ScoreColumnKind, scoreIndex, ref kind);
Assert.True(ReadOnlyMemoryUtils.EqualsStr(MetadataUtils.Const.ScoreColumnKind.Regression, kind),
$"Regression ensemble scored data column type should be '{MetadataUtils.Const.ScoreColumnKind.Regression}', but is instead '{kind}'");
@@ -868,7 +868,7 @@ public void EntryPointPipelineEnsemble()
hasScoreCol = anomalyScored.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreIndex);
Assert.True(hasScoreCol, "Data scored with anomaly detection ensemble does not have a score column");
type = anomalyScored.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.ScoreColumnKind, scoreIndex);
- Assert.True(type != null && type.IsText, "Anomaly detection ensemble scored data does not have correct type of metadata.");
+ Assert.True(type is TextType, "Anomaly detection ensemble scored data does not have correct type of metadata.");
anomalyScored.Schema.GetMetadata(MetadataUtils.Kinds.ScoreColumnKind, scoreIndex, ref kind);
Assert.True(ReadOnlyMemoryUtils.EqualsStr(MetadataUtils.Const.ScoreColumnKind.AnomalyDetection, kind),
$"Anomaly detection ensemble scored data column type should be '{MetadataUtils.Const.ScoreColumnKind.AnomalyDetection}', but is instead '{kind}'");
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestModelLoad.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestModelLoad.cs
index 8891b49d22..87f77a9a3f 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestModelLoad.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestModelLoad.cs
@@ -46,8 +46,8 @@ public void LoadOldConcatTransformModel()
Assert.Equal("Label", result.Schema[0].Name);
Assert.Equal("Features", result.Schema[1].Name);
Assert.Equal("Features", result.Schema[2].Name);
- Assert.Equal(9, result.Schema[1].Type.VectorSize);
- Assert.Equal(18, result.Schema[2].Type.VectorSize);
+ Assert.Equal(9, (result.Schema[1].Type as VectorType)?.Size);
+ Assert.Equal(18, (result.Schema[2].Type as VectorType)?.Size);
}
}
}
diff --git a/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs b/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs
index 4ce87f070b..c6cdbf543a 100644
--- a/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs
+++ b/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs
@@ -24,16 +24,18 @@ private static T[] NaiveTranspose(IDataView view, int col)
{
var type = view.Schema.GetColumnType(col);
int rc = checked((int)DataViewUtils.ComputeRowCount(view));
- Assert.True(type.ItemType.RawType == typeof(T));
- Assert.True(type.ValueCount > 0);
- T[] retval = new T[rc * type.ValueCount];
+ var vecType = type as VectorType;
+ var itemType = vecType?.ItemType ?? type;
+ Assert.Equal(typeof(T), itemType.RawType);
+ Assert.NotEqual(0, vecType?.Size);
+ T[] retval = new T[rc * (vecType?.Size ?? 1)];
using (var cursor = view.GetRowCursor(c => c == col))
{
- if (type.IsVector)
+ if (type is VectorType)
{
var getter = cursor.GetGetter>(col);
- VBuffer temp = default(VBuffer);
+ VBuffer temp = default;
int offset = 0;
while (cursor.MoveNext())
{
@@ -60,21 +62,20 @@ private static T[] NaiveTranspose(IDataView view, int col)
private static void TransposeCheckHelper(IDataView view, int viewCol, ITransposeDataView trans)
{
int col = viewCol;
- var type = trans.TransposeSchema.GetSlotType(col);
- var colType = trans.Schema.GetColumnType(col);
+ VectorType type = trans.TransposeSchema.GetSlotType(col);
+ ColumnType colType = trans.Schema.GetColumnType(col);
Assert.Equal(view.Schema.GetColumnName(viewCol), trans.Schema.GetColumnName(col));
- var expectedType = view.Schema.GetColumnType(viewCol);
- // Unfortunately can't use equals because column type equality is a simple reference comparison. :P
+ ColumnType expectedType = view.Schema.GetColumnType(viewCol);
Assert.Equal(expectedType, colType);
- Assert.Equal(DataViewUtils.ComputeRowCount(view), (long)type.VectorSize);
string desc = string.Format("Column {0} named '{1}'", col, trans.Schema.GetColumnName(col));
+ Assert.Equal(DataViewUtils.ComputeRowCount(view), (long)type.Size);
Assert.True(typeof(T) == type.ItemType.RawType, $"{desc} had wrong type for slot cursor");
- Assert.True(type.IsVector, $"{desc} expected to be vector but is not");
- Assert.True(type.VectorSize > 0, $"{desc} expected to be known sized vector but is not");
- Assert.True(0 != colType.ValueCount, $"{desc} expected to have fixed size, but does not");
- int rc = type.VectorSize;
+ Assert.True(type.Size > 0, $"{desc} expected to be known sized vector but is not");
+ int valueCount = (colType as VectorType)?.Size ?? 1;
+ Assert.True(0 != valueCount, $"{desc} expected to have fixed size, but does not");
+ int rc = type.Size;
T[] expectedVals = NaiveTranspose(view, viewCol);
- T[] vals = new T[rc * colType.ValueCount];
+ T[] vals = new T[rc * valueCount];
Contracts.Assert(vals.Length == expectedVals.Length);
using (var cursor = trans.GetSlotCursor(col))
{
@@ -89,7 +90,7 @@ private static void TransposeCheckHelper(IDataView view, int viewCol, ITransp
temp.CopyTo(vals, offset);
offset += rc;
}
- Assert.True(colType.ValueCount == offset / rc, $"{desc} slot cursor yielded fewer than expected values");
+ Assert.True(valueCount == offset / rc, $"{desc} slot cursor yielded fewer than expected values");
}
for (int i = 0; i < vals.Length; ++i)
Assert.Equal(expectedVals[i], vals[i]);
diff --git a/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs b/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs
index ea6fee04d3..3b28fc1cfa 100644
--- a/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs
+++ b/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs
@@ -28,13 +28,13 @@ public void SimpleImageSmokeTest()
var schema = reader.AsDynamic.GetOutputSchema();
Assert.True(schema.TryGetColumnIndex("Data", out int col), "Could not find 'Data' column");
var type = schema.GetColumnType(col);
- Assert.True(type.IsKnownSizeVector, $"Type was supposed to be known size vector but was instead '{type}'");
- var vecType = type.AsVector;
+ var vecType = type as VectorType;
+ Assert.True(vecType?.Size > 0, $"Type was supposed to be known size vector but was instead '{type}'");
Assert.Equal(NumberType.R4, vecType.ItemType);
- Assert.Equal(3, vecType.DimCount);
- Assert.Equal(3, vecType.GetDim(0));
- Assert.Equal(8, vecType.GetDim(1));
- Assert.Equal(10, vecType.GetDim(2));
+ Assert.Equal(3, vecType.Dimensions.Length);
+ Assert.Equal(3, vecType.Dimensions[0]);
+ Assert.Equal(8, vecType.Dimensions[1]);
+ Assert.Equal(10, vecType.Dimensions[2]);
var readAsImage = TextLoader.CreateReader(env,
ctx => ctx.LoadText(0).LoadAsImage());
diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
index 3a3eff6289..de118ca2f5 100644
--- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
+++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
@@ -210,7 +210,7 @@ public void AssertStaticSimple()
var schema = SimpleSchemaUtils.Create(env,
P("hello", TextType.Instance),
P("my", new VectorType(NumberType.I8, 5)),
- P("friend", new KeyType(DataKind.U4, 0, 3)));
+ P("friend", new KeyType(typeof(uint), 0, 3)));
var view = new EmptyDataView(env, schema);
view.AssertStatic(env, c => new
@@ -234,7 +234,7 @@ public void AssertStaticSimpleFailure()
var schema = SimpleSchemaUtils.Create(env,
P("hello", TextType.Instance),
P("my", new VectorType(NumberType.I8, 5)),
- P("friend", new KeyType(DataKind.U4, 0, 3)));
+ P("friend", new KeyType(typeof(uint), 0, 3)));
var view = new EmptyDataView(env, schema);
Assert.ThrowsAny(() =>
@@ -269,23 +269,23 @@ public void AssertStaticKeys()
var metaValues1 = new VBuffer>(3, new[] { "a".AsMemory(), "b".AsMemory(), "c".AsMemory() });
var meta1 = RowColumnUtils.GetColumn(MetadataUtils.Kinds.KeyValues, new VectorType(TextType.Instance, 3), ref metaValues1);
uint value1 = 2;
- var col1 = RowColumnUtils.GetColumn("stay", new KeyType(DataKind.U4, 0, 3), ref value1, RowColumnUtils.GetRow(counted, meta1));
+ var col1 = RowColumnUtils.GetColumn("stay", new KeyType(typeof(uint), 0, 3), ref value1, RowColumnUtils.GetRow(counted, meta1));
// Next the case where those values are ints.
var metaValues2 = new VBuffer(3, new int[] { 1, 2, 3, 4 });
var meta2 = RowColumnUtils.GetColumn(MetadataUtils.Kinds.KeyValues, new VectorType(NumberType.I4, 4), ref metaValues2);
var value2 = new VBuffer(2, 0, null, null);
- var col2 = RowColumnUtils.GetColumn("awhile", new VectorType(new KeyType(DataKind.U1, 2, 4), 2), ref value2, RowColumnUtils.GetRow(counted, meta2));
+ var col2 = RowColumnUtils.GetColumn("awhile", new VectorType(new KeyType(typeof(byte), 2, 4), 2), ref value2, RowColumnUtils.GetRow(counted, meta2));
// Then the case where a value of that kind exists, but is of not of the right kind, in which case it should not be identified as containing that metadata.
var metaValues3 = (float)2;
var meta3 = RowColumnUtils.GetColumn(MetadataUtils.Kinds.KeyValues, NumberType.R4, ref metaValues3);
var value3 = (ushort)1;
- var col3 = RowColumnUtils.GetColumn("and", new KeyType(DataKind.U2, 0, 2), ref value3, RowColumnUtils.GetRow(counted, meta3));
+ var col3 = RowColumnUtils.GetColumn("and", new KeyType(typeof(ushort), 0, 2), ref value3, RowColumnUtils.GetRow(counted, meta3));
// Then a final case where metadata of that kind is actaully simply altogether absent.
var value4 = new VBuffer(5, 0, null, null);
- var col4 = RowColumnUtils.GetColumn("listen", new VectorType(new KeyType(DataKind.U4, 0, 2)), ref value4);
+ var col4 = RowColumnUtils.GetColumn("listen", new VectorType(new KeyType(typeof(uint), 0, 2)), ref value4);
// Finally compose a trivial data view out of all this.
var row = RowColumnUtils.GetRow(counted, col1, col2, col3, col4);
@@ -456,9 +456,9 @@ public void ToKey()
Assert.True(schema.TryGetColumnIndex("valuesKey", out int valuesCol));
Assert.True(schema.TryGetColumnIndex("valuesKeyKey", out int valuesKeyCol));
- Assert.Equal(3, schema.GetColumnType(labelCol).KeyCount);
- Assert.True(schema.GetColumnType(valuesCol).ItemType.IsKey);
- Assert.True(schema.GetColumnType(valuesKeyCol).ItemType.IsKey);
+ Assert.Equal(3, (schema.GetColumnType(labelCol) as KeyType)?.Count);
+ Assert.True(schema.GetColumnType(valuesCol) is VectorType valuesVecType && valuesVecType.ItemType is KeyType);
+ Assert.True(schema.GetColumnType(valuesKeyCol) is VectorType valuesKeyVecType && valuesKeyVecType.ItemType is KeyType);
var labelKeyType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, labelCol);
var valuesKeyType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, valuesCol);
@@ -466,9 +466,9 @@ public void ToKey()
Assert.NotNull(labelKeyType);
Assert.NotNull(valuesKeyType);
Assert.NotNull(valuesKeyKeyType);
- Assert.True(labelKeyType.IsVector && labelKeyType.ItemType == TextType.Instance);
- Assert.True(valuesKeyType.IsVector && valuesKeyType.ItemType == NumberType.Float);
- Assert.True(valuesKeyKeyType.IsVector && valuesKeyKeyType.ItemType == NumberType.Float);
+ Assert.True(labelKeyType is VectorType labelVecType && labelVecType.ItemType == TextType.Instance);
+ Assert.True(valuesKeyType is VectorType valuesVecType2 && valuesVecType2.ItemType == NumberType.Float);
+ Assert.True(valuesKeyKeyType is VectorType valuesKeyVecType2 && valuesKeyVecType2.ItemType == NumberType.Float);
// Because they're over exactly the same data, they ought to have the same cardinality and everything.
Assert.True(valuesKeyKeyType.Equals(valuesKeyType));
}
@@ -501,9 +501,9 @@ public void ConcatWith()
for (int i = 0; i < idx.Length; ++i)
{
var type = schema.GetColumnType(idx[i]);
- Assert.True(type.VectorSize > 0, $"Col c{i} had unexpected type {type}");
- types[i] = type.AsVector;
- Assert.Equal(expectedLen[i], type.VectorSize);
+ types[i] = type as VectorType;
+ Assert.True(types[i]?.Size > 0, $"Col c{i} had unexpected type {type}");
+ Assert.Equal(expectedLen[i], types[i].Size);
}
Assert.Equal(TextType.Instance, types[0].ItemType);
Assert.Equal(TextType.Instance, types[1].ItemType);
@@ -533,12 +533,11 @@ public void Tokenize()
Assert.True(schema.TryGetColumnIndex("tokens", out int tokensCol));
var type = schema.GetColumnType(tokensCol);
- Assert.True(type.IsVector && !type.IsKnownSizeVector && type.ItemType.IsText);
-
+ Assert.True(type is VectorType vecType && vecType.Size == 0 && vecType.ItemType == TextType.Instance);
Assert.True(schema.TryGetColumnIndex("chars", out int charsCol));
type = schema.GetColumnType(charsCol);
- Assert.True(type.IsVector && !type.IsKnownSizeVector && type.ItemType.IsKey);
- Assert.True(type.ItemType.AsKey.RawKind == DataKind.U2);
+ Assert.True(type is VectorType vecType2 && vecType2.Size == 0 && vecType2.ItemType is KeyType
+ && vecType2.ItemType.RawType == typeof(ushort));
}
[Fact]
@@ -563,11 +562,11 @@ public void NormalizeTextAndRemoveStopWords()
Assert.True(schema.TryGetColumnIndex("words_without_stopwords", out int stopwordsCol));
var type = schema.GetColumnType(stopwordsCol);
- Assert.True(type.IsVector && !type.IsKnownSizeVector && type.ItemType.IsText);
+ Assert.True(type is VectorType vecType && vecType.Size == 0 && vecType.ItemType == TextType.Instance);
Assert.True(schema.TryGetColumnIndex("normalized_text", out int normTextCol));
type = schema.GetColumnType(normTextCol);
- Assert.True(type.IsText && !type.IsVector);
+ Assert.Equal(TextType.Instance, type);
}
[Fact]
@@ -592,11 +591,11 @@ public void ConvertToWordBag()
Assert.True(schema.TryGetColumnIndex("bagofword", out int bagofwordCol));
var type = schema.GetColumnType(bagofwordCol);
- Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber);
+ Assert.True(type is VectorType vecType && vecType.Size > 0&& vecType.ItemType is NumberType);
Assert.True(schema.TryGetColumnIndex("bagofhashedword", out int bagofhashedwordCol));
type = schema.GetColumnType(bagofhashedwordCol);
- Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber);
+ Assert.True(type is VectorType vecType2 && vecType2.Size > 0 && vecType2.ItemType is NumberType);
}
[Fact]
@@ -621,11 +620,11 @@ public void Ngrams()
Assert.True(schema.TryGetColumnIndex("ngrams", out int ngramsCol));
var type = schema.GetColumnType(ngramsCol);
- Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber);
+ Assert.True(type is VectorType vecType && vecType.Size > 0 && vecType.ItemType is NumberType);
Assert.True(schema.TryGetColumnIndex("ngramshash", out int ngramshashCol));
type = schema.GetColumnType(ngramshashCol);
- Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber);
+ Assert.True(type is VectorType vecType2 && vecType2.Size > 0 && vecType2.ItemType is NumberType);
}
@@ -652,19 +651,19 @@ public void LpGcNormAndWhitening()
Assert.True(schema.TryGetColumnIndex("lpnorm", out int lpnormCol));
var type = schema.GetColumnType(lpnormCol);
- Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber);
+ Assert.True(type is VectorType vecType && vecType.Size > 0 && vecType.ItemType is NumberType);
Assert.True(schema.TryGetColumnIndex("gcnorm", out int gcnormCol));
type = schema.GetColumnType(gcnormCol);
- Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber);
+ Assert.True(type is VectorType vecType2 && vecType2.Size > 0 && vecType2.ItemType is NumberType);
Assert.True(schema.TryGetColumnIndex("zcawhitened", out int zcawhitenedCol));
type = schema.GetColumnType(zcawhitenedCol);
- Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber);
+ Assert.True(type is VectorType vecType3 && vecType3.Size > 0 && vecType3.ItemType is NumberType);
Assert.True(schema.TryGetColumnIndex("pcswhitened", out int pcswhitenedCol));
type = schema.GetColumnType(pcswhitenedCol);
- Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber);
+ Assert.True(type is VectorType vecType4 && vecType4.Size > 0 && vecType4.ItemType is NumberType);
}
[Fact]
@@ -691,8 +690,8 @@ public void LdaTopicModel()
Assert.True(schema.TryGetColumnIndex("topics", out int topicsCol));
var type = schema.GetColumnType(topicsCol);
- Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber);
- }
+ Assert.True(type is VectorType vecType && vecType.Size > 0 && vecType.ItemType is NumberType);
+}
[Fact]
public void FeatureSelection()
@@ -716,11 +715,11 @@ public void FeatureSelection()
Assert.True(schema.TryGetColumnIndex("bag_of_words_count", out int bagofwordCountCol));
var type = schema.GetColumnType(bagofwordCountCol);
- Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber);
+ Assert.True(type is VectorType vecType && vecType.Size > 0 && vecType.ItemType is NumberType);
Assert.True(schema.TryGetColumnIndex("bag_of_words_mi", out int bagofwordMiCol));
type = schema.GetColumnType(bagofwordMiCol);
- Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber);
+ Assert.True(type is VectorType vecType2 && vecType2.Size > 0 && vecType2.ItemType is NumberType);
}
[Fact]
@@ -773,7 +772,7 @@ public void PrincipalComponentAnalysis()
Assert.True(schema.TryGetColumnIndex("pca", out int pcaCol));
var type = schema.GetColumnType(pcaCol);
- Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber);
+ Assert.True(type is VectorType vecType && vecType.Size > 0 && vecType.ItemType is NumberType);
}
[Fact]
@@ -844,20 +843,20 @@ public void TextNormalizeStatic()
Assert.True(schema.TryGetColumnIndex("norm", out int norm));
var type = schema.GetColumnType(norm);
- Assert.True(!type.IsVector && type.ItemType.IsText);
+ Assert.True(type is TextType);
Assert.True(schema.TryGetColumnIndex("norm_Upper", out int normUpper));
type = schema.GetColumnType(normUpper);
- Assert.True(!type.IsVector && type.ItemType.IsText);
+ Assert.True(type is TextType);
Assert.True(schema.TryGetColumnIndex("norm_KeepDiacritics", out int diacritics));
type = schema.GetColumnType(diacritics);
- Assert.True(!type.IsVector && type.ItemType.IsText);
+ Assert.True(type is TextType);
Assert.True(schema.TryGetColumnIndex("norm_NoPuctuations", out int punct));
type = schema.GetColumnType(punct);
- Assert.True(!type.IsVector && type.ItemType.IsText);
+ Assert.True(type is TextType);
Assert.True(schema.TryGetColumnIndex("norm_NoNumbers", out int numbers));
type = schema.GetColumnType(numbers);
- Assert.True(!type.IsVector && type.ItemType.IsText);
+ Assert.True(type is TextType);
}
[Fact]
@@ -876,8 +875,7 @@ public void TestPcaStatic()
Assert.True(schema.TryGetColumnIndex("pca", out int pca));
var type = schema[pca].Type;
- Assert.True(type.IsVector && type.ItemType.RawKind == DataKind.R4);
- Assert.True(type.VectorSize == 5);
+ Assert.Equal(new VectorType(NumberType.R4, 5), type);
}
[Fact]
@@ -900,14 +898,13 @@ public void TestConvertStatic()
Assert.True(schema.TryGetColumnIndex("floatLabel", out int floatLabel));
var type = schema[floatLabel].Type;
- Assert.True(!type.IsVector && type.ItemType.RawKind == DataKind.R4);
+ Assert.Equal(NumberType.R4, type);
Assert.True(schema.TryGetColumnIndex("txtFloat", out int txtFloat));
type = schema[txtFloat].Type;
- Assert.True(!type.IsVector && type.ItemType.RawKind == DataKind.R4);
+ Assert.Equal(NumberType.R4, type);
Assert.True(schema.TryGetColumnIndex("num", out int num));
type = schema[num].Type;
- Assert.True(type.IsVector && type.ItemType.RawKind == DataKind.R4);
- Assert.True(type.VectorSize == 3);
+ Assert.Equal(new VectorType(NumberType.R4, 3), type);
}
}
}
\ No newline at end of file
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
index 503ec4c3c7..4cbbafb346 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
@@ -194,13 +194,13 @@ public void TensorFlowInputsOutputsSchemaTest()
var schema = TensorFlowUtils.GetModelSchema(env, model_location);
Assert.Equal(86, schema.ColumnCount);
Assert.True(schema.TryGetColumnIndex("Placeholder", out int col));
- var type = schema.GetColumnType(col).AsVector;
- Assert.Equal(2, type.DimCount);
- Assert.Equal(28, type.GetDim(0));
- Assert.Equal(28, type.GetDim(1));
+ var type = (VectorType)schema.GetColumnType(col);
+ Assert.Equal(2, type.Dimensions.Length);
+ Assert.Equal(28, type.Dimensions[0]);
+ Assert.Equal(28, type.Dimensions[1]);
var metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col);
Assert.NotNull(metadataType);
- Assert.True(metadataType.IsText);
+ Assert.True(metadataType is TextType);
ReadOnlyMemory opType = default;
schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType);
Assert.Equal("Placeholder", opType.ToString());
@@ -208,15 +208,11 @@ public void TensorFlowInputsOutputsSchemaTest()
Assert.Null(metadataType);
Assert.True(schema.TryGetColumnIndex("conv2d/Conv2D/ReadVariableOp", out col));
- type = schema.GetColumnType(col).AsVector;
- Assert.Equal(4, type.DimCount);
- Assert.Equal(5, type.GetDim(0));
- Assert.Equal(5, type.GetDim(1));
- Assert.Equal(1, type.GetDim(2));
- Assert.Equal(32, type.GetDim(3));
+ type = (VectorType)schema.GetColumnType(col);
+ Assert.Equal(new[] { 5, 5, 1, 32 }, type.Dimensions);
metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col);
Assert.NotNull(metadataType);
- Assert.True(metadataType.IsText);
+ Assert.True(metadataType is TextType);
schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType);
Assert.Equal("Identity", opType.ToString());
metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col);
@@ -227,14 +223,11 @@ public void TensorFlowInputsOutputsSchemaTest()
Assert.Equal("conv2d/kernel", inputOps.Values[0].ToString());
Assert.True(schema.TryGetColumnIndex("conv2d/Conv2D", out col));
- type = schema.GetColumnType(col).AsVector;
- Assert.Equal(3, type.DimCount);
- Assert.Equal(28, type.GetDim(0));
- Assert.Equal(28, type.GetDim(1));
- Assert.Equal(32, type.GetDim(2));
+ type = (VectorType)schema.GetColumnType(col);
+ Assert.Equal(new[] { 28, 28, 32 }, type.Dimensions);
metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col);
Assert.NotNull(metadataType);
- Assert.True(metadataType.IsText);
+ Assert.True(metadataType is TextType);
schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType);
Assert.Equal("Conv2D", opType.ToString());
metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col);
@@ -245,12 +238,11 @@ public void TensorFlowInputsOutputsSchemaTest()
Assert.Equal("conv2d/Conv2D/ReadVariableOp", inputOps.Values[1].ToString());
Assert.True(schema.TryGetColumnIndex("Softmax", out col));
- type = schema.GetColumnType(col).AsVector;
- Assert.Equal(1, type.DimCount);
- Assert.Equal(10, type.GetDim(0));
+ type = (VectorType)schema.GetColumnType(col);
+ Assert.Equal(new[] { 10 }, type.Dimensions);
metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col);
Assert.NotNull(metadataType);
- Assert.True(metadataType.IsText);
+ Assert.True(metadataType is TextType);
schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType);
Assert.Equal("Softmax", opType.ToString());
metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col);
@@ -265,10 +257,8 @@ public void TensorFlowInputsOutputsSchemaTest()
for (int i = 0; i < schema.ColumnCount; i++)
{
Assert.Equal(name.ToString(), schema.GetColumnName(i));
- type = schema.GetColumnType(i).AsVector;
- Assert.Equal(2, type.DimCount);
- Assert.Equal(2, type.GetDim(0));
- Assert.Equal(2, type.GetDim(1));
+ type = (VectorType)schema.GetColumnType(i);
+ Assert.Equal(new[] { 2, 2 }, type.Dimensions);
name++;
}
}
@@ -787,9 +777,9 @@ public void TensorFlowTransformCifar()
var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(env, model_location);
var schema = tensorFlowModel.GetInputSchema();
Assert.True(schema.TryGetColumnIndex("Input", out int column));
- var type = schema.GetColumnType(column).AsVector;
- var imageHeight = type.GetDim(0);
- var imageWidth = type.GetDim(1);
+ var type = (VectorType)schema.GetColumnType(column);
+ var imageHeight = type.Dimensions[0];
+ var imageWidth = type.Dimensions[1];
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
@@ -853,9 +843,9 @@ public void TensorFlowTransformCifarSavedModel()
var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(env, model_location);
var schema = tensorFlowModel.GetInputSchema();
Assert.True(schema.TryGetColumnIndex("Input", out int column));
- var type = schema.GetColumnType(column).AsVector;
- var imageHeight = type.GetDim(0);
- var imageWidth = type.GetDim(1);
+ var type = (VectorType)schema.GetColumnType(column);
+ var imageHeight = type.Dimensions[0];
+ var imageWidth = type.Dimensions[1];
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
diff --git a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs
index 690bd0dd2f..2f2820692f 100644
--- a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs
+++ b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs
@@ -193,9 +193,9 @@ public void TestTensorFlowStaticWithSchema()
var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(env, modelLocation);
var schema = tensorFlowModel.GetInputSchema();
Assert.True(schema.TryGetColumnIndex("Input", out int column));
- var type = schema.GetColumnType(column).AsVector;
- var imageHeight = type.GetDim(0);
- var imageWidth = type.GetDim(1);
+ var type = (VectorType)schema.GetColumnType(column);
+ var imageHeight = type.Dimensions[0];
+ var imageWidth = type.Dimensions[1];
var dataFile = GetDataPath("images/images.tsv");
var imageFolder = Path.GetDirectoryName(dataFile);
diff --git a/test/Microsoft.ML.Tests/TermEstimatorTests.cs b/test/Microsoft.ML.Tests/TermEstimatorTests.cs
index b5f8ddbe7f..c543523972 100644
--- a/test/Microsoft.ML.Tests/TermEstimatorTests.cs
+++ b/test/Microsoft.ML.Tests/TermEstimatorTests.cs
@@ -146,7 +146,8 @@ void TestMetadataCopy()
result.Schema.TryGetColumnIndex("T", out int termIndex);
var names1 = default(VBuffer>);
var type1 = result.Schema.GetColumnType(termIndex);
- int size = type1.ItemType.IsKey ? type1.ItemType.KeyCount : -1;
+ var itemType1 = (type1 as VectorType)?.ItemType ?? type1;
+ int size = itemType1 is KeyType keyType ? keyType.Count : -1;
result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, termIndex, ref names1);
Assert.True(names1.Count > 0);
}
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs
index bfa22dce86..cb5a8b2748 100644
--- a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs
@@ -194,12 +194,12 @@ public void MatrixFactorizationInMemoryData()
// Check if the expected types in the trained model are expected.
Assert.True(model.MatrixColumnIndexColumnName == "MatrixColumnIndex");
Assert.True(model.MatrixRowIndexColumnName == "MatrixRowIndex");
- Assert.True(model.MatrixColumnIndexColumnType.IsKey);
- Assert.True(model.MatrixRowIndexColumnType.IsKey);
- var matColKeyType = model.MatrixColumnIndexColumnType.AsKey;
+ Assert.True(model.MatrixColumnIndexColumnType is KeyType);
+ Assert.True(model.MatrixRowIndexColumnType is KeyType);
+ var matColKeyType = (KeyType)model.MatrixColumnIndexColumnType;
Assert.True(matColKeyType.Min == _synthesizedMatrixFirstColumnIndex);
Assert.True(matColKeyType.Count == _synthesizedMatrixColumnCount);
- var matRowKeyType = model.MatrixRowIndexColumnType.AsKey;
+ var matRowKeyType = (KeyType)model.MatrixRowIndexColumnType;
Assert.True(matRowKeyType.Min == _synthesizedMatrixFirstRowIndex);
Assert.True(matRowKeyType.Count == _synthesizedMatrixRowCount);
@@ -285,12 +285,12 @@ public void MatrixFactorizationInMemoryDataZeroBaseIndex()
// Check if the expected types in the trained model are expected.
Assert.True(model.MatrixColumnIndexColumnName == nameof(MatrixElementZeroBased.MatrixColumnIndex));
Assert.True(model.MatrixRowIndexColumnName == nameof(MatrixElementZeroBased.MatrixRowIndex));
- Assert.True(model.MatrixColumnIndexColumnType.IsKey);
- Assert.True(model.MatrixRowIndexColumnType.IsKey);
- var matColKeyType = model.MatrixColumnIndexColumnType.AsKey;
+ var matColKeyType = model.MatrixColumnIndexColumnType as KeyType;
+ Assert.NotNull(matColKeyType);
+ var matRowKeyType = model.MatrixRowIndexColumnType as KeyType;
+ Assert.NotNull(matRowKeyType);
Assert.True(matColKeyType.Min == 0);
Assert.True(matColKeyType.Count == _synthesizedMatrixColumnCount);
- var matRowKeyType = model.MatrixRowIndexColumnType.AsKey;
Assert.True(matRowKeyType.Min == 0);
Assert.True(matRowKeyType.Count == _synthesizedMatrixRowCount);
diff --git a/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs b/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs
index 9133bdbcf9..efb3833be3 100644
--- a/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs
@@ -55,13 +55,13 @@ ColumnType GetType(Schema schema, string name)
ColumnType t;
t = GetType(data.Schema, "f1");
- Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 1);
+ Assert.True(t is VectorType vt1 && vt1.ItemType == NumberType.R4 && vt1.Size == 1);
t = GetType(data.Schema, "f2");
- Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 2);
+ Assert.True(t is VectorType vt2 && vt2.ItemType == NumberType.R4 && vt2.Size == 2);
t = GetType(data.Schema, "f3");
- Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 5);
+ Assert.True(t is VectorType vt3 && vt3.ItemType == NumberType.R4 && vt3.Size == 5);
t = GetType(data.Schema, "f4");
- Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 0);
+ Assert.True(t is VectorType vt4 && vt4.ItemType == NumberType.R4 && vt4.Size == 0);
data = SelectColumnsTransform.CreateKeep(Env, data, "f1", "f2", "f3", "f4");
@@ -111,9 +111,9 @@ ColumnType GetType(Schema schema, string name)
ColumnType t;
t = GetType(data.Schema, "f2");
- Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 2);
+ Assert.True(t is VectorType vt2 && vt2.ItemType == NumberType.R4 && vt2.Size == 2);
t = GetType(data.Schema, "f3");
- Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 5);
+ Assert.True(t is VectorType vt3 && vt3.ItemType == NumberType.R4 && vt3.Size == 5);
data = SelectColumnsTransform.CreateKeep(Env, data, "f2", "f3");
diff --git a/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs
index 019103ff44..5e63e694ba 100644
--- a/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs
@@ -152,7 +152,8 @@ void TestMetadataCopy()
var names1 = default(VBuffer>);
var names2 = default(VBuffer>);
var type1 = result.Schema.GetColumnType(termIndex);
- int size = type1.ItemType.IsKey ? type1.ItemType.KeyCount : -1;
+ var itemType1 = (type1 as VectorType)?.ItemType ?? type1;
+ int size = (itemType1 as KeyType)?.Count ?? -1;
var type2 = result.Schema.GetColumnType(copyIndex);
result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, termIndex, ref names1);
result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, copyIndex, ref names2);
diff --git a/test/Microsoft.ML.Tests/Transformers/HashTests.cs b/test/Microsoft.ML.Tests/Transformers/HashTests.cs
index e2c2a0159f..a8812b1337 100644
--- a/test/Microsoft.ML.Tests/Transformers/HashTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/HashTests.cs
@@ -245,20 +245,20 @@ private void HashTestPositiveIntegerCore(ulong value, uint expected, uint expect
if (value <= byte.MaxValue)
{
HashTestCore((byte)value, NumberType.U1, expected, expectedOrdered, expectedOrdered3);
- HashTestCore((byte)value, new KeyType(DataKind.U1, 0, byte.MaxValue - 1), eKey, eoKey, e3Key);
+ HashTestCore((byte)value, new KeyType(typeof(byte), 0, byte.MaxValue - 1), eKey, eoKey, e3Key);
}
if (value <= ushort.MaxValue)
{
HashTestCore((ushort)value, NumberType.U2, expected, expectedOrdered, expectedOrdered3);
- HashTestCore((ushort)value, new KeyType(DataKind.U2, 0, ushort.MaxValue - 1), eKey, eoKey, e3Key);
+ HashTestCore((ushort)value, new KeyType(typeof(ushort), 0, ushort.MaxValue - 1), eKey, eoKey, e3Key);
}
if (value <= uint.MaxValue)
{
HashTestCore((uint)value, NumberType.U4, expected, expectedOrdered, expectedOrdered3);
- HashTestCore((uint)value, new KeyType(DataKind.U4, 0, int.MaxValue - 1), eKey, eoKey, e3Key);
+ HashTestCore((uint)value, new KeyType(typeof(uint), 0, int.MaxValue - 1), eKey, eoKey, e3Key);
}
HashTestCore(value, NumberType.U8, expected, expectedOrdered, expectedOrdered3);
- HashTestCore((ulong)value, new KeyType(DataKind.U8, 0, 0), eKey, eoKey, e3Key);
+ HashTestCore((ulong)value, new KeyType(typeof(ulong), 0, 0), eKey, eoKey, e3Key);
HashTestCore(new UInt128(value, 0), NumberType.UG, expected, expectedOrdered, expectedOrdered3);
diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs
index 0c909ef97b..b436c96149 100644
--- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs
@@ -271,7 +271,7 @@ public void LdaWorkout()
using (var fs = File.Create(outputPath))
DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true);
- Assert.Equal(10, savedData.Schema.GetColumnType(0).VectorSize);
+ Assert.Equal(10, (savedData.Schema.GetColumnType(0) as VectorType)?.Size);
}
// Diabling this check due to the following issue with consitency of output.