diff --git a/src/Microsoft.ML.Api/CodeGenerationUtils.cs b/src/Microsoft.ML.Api/CodeGenerationUtils.cs index 7af0fb85ed..1c39f1b9dd 100644 --- a/src/Microsoft.ML.Api/CodeGenerationUtils.cs +++ b/src/Microsoft.ML.Api/CodeGenerationUtils.cs @@ -75,11 +75,11 @@ public static void AppendFieldDeclaration(CSharpCodeProvider codeProvider, Strin target.Append(indent); target.AppendFormat("public {0} {1}", generatedCsTypeName, fieldName); - if (appendInitializer && colType.IsKnownSizeVector && !useVBuffer) + if (appendInitializer && colType is VectorType vecColType && vecColType.Size > 0 && !useVBuffer) { Contracts.Assert(generatedCsTypeName.EndsWith("[]")); var csItemType = generatedCsTypeName.Substring(0, generatedCsTypeName.Length - 2); - target.AppendFormat(" = new {0}[{1}]", csItemType, colType.VectorSize); + target.AppendFormat(" = new {0}[{1}]", csItemType, vecColType.Size); } target.AppendLine(";"); } @@ -109,17 +109,13 @@ private static string GetBackingTypeName(ColumnType colType, bool useVBuffer, Li { Contracts.AssertValue(colType); Contracts.AssertValue(attributes); - if (colType.IsVector) + if (colType is VectorType vecColType) { - if (colType.IsKnownSizeVector) + if (vecColType.Size > 0) { // By default, arrays are assumed variable length, unless a [VectorType(dim1, dim2, ...)] // attribute is applied to the fields. 
- var vectorType = colType.AsVector; - var dimensions = new int[vectorType.DimCount]; - for (int i = 0; i < dimensions.Length; i++) - dimensions[i] = vectorType.GetDim(i); - attributes.Add(string.Format("[VectorType({0})]", string.Join(", ", dimensions))); + attributes.Add(string.Format("[VectorType({0})]", string.Join(", ", vecColType.Dimensions))); } var itemType = GetBackingTypeName(colType.ItemType, false, attributes); diff --git a/src/Microsoft.ML.Core/Data/ColumnType.cs b/src/Microsoft.ML.Core/Data/ColumnType.cs index 69a6505c51..83b5caa27d 100644 --- a/src/Microsoft.ML.Core/Data/ColumnType.cs +++ b/src/Microsoft.ML.Core/Data/ColumnType.cs @@ -5,6 +5,7 @@ #pragma warning disable 420 // volatile with Interlocked.CompareExchange using System; +using System.Collections.Immutable; using System.Linq; using System.Reflection; using System.Text; @@ -14,41 +15,36 @@ namespace Microsoft.ML.Runtime.Data { /// - /// ColumnType is the abstract base class for all types in the IDataView type system. + /// This is the abstract base class for all types in the type system. /// public abstract class ColumnType : IEquatable { - private readonly Type _rawType; - private readonly DataKind _rawKind; - - // We cache these for speed and code size. - private readonly bool _isPrimitive; - private readonly bool _isVector; - private readonly bool _isNumber; - private readonly bool _isKey; - - // This private constructor sets all the _isXxx flags. It is invoked by other ctors. + // This private constructor sets all the IsXxx flags. It is invoked by other ctors. private ColumnType() { - _isPrimitive = this is PrimitiveType; - _isVector = this is VectorType; - _isNumber = this is NumberType; - _isKey = this is KeyType; + IsPrimitive = this is PrimitiveType; + IsVector = this is VectorType; + IsNumber = this is NumberType; + IsKey = this is KeyType; } - protected ColumnType(Type rawType) + /// + /// Constructor for extension types, which must be either or . 
+ /// + private protected ColumnType(Type rawType) : this() { Contracts.CheckValue(rawType, nameof(rawType)); - _rawType = rawType; - _rawType.TryGetDataKind(out _rawKind); + RawType = rawType; + RawType.TryGetDataKind(out var rawKind); + RawKind = rawKind; } /// - /// Internal sub types can pass both the rawType and rawKind values. This asserts that they - /// are consistent. + /// Internal sub types can pass both the and values. + /// This asserts that they are consistent. /// - internal ColumnType(Type rawType, DataKind rawKind) + private protected ColumnType(Type rawType, DataKind rawKind) : this() { Contracts.AssertValue(rawType); @@ -57,42 +53,50 @@ internal ColumnType(Type rawType, DataKind rawKind) rawType.TryGetDataKind(out tmp); Contracts.Assert(tmp == rawKind); #endif - _rawType = rawType; - _rawKind = rawKind; + RawType = rawType; + RawKind = rawKind; } /// - /// The raw System.Type for this ColumnType. Note that this is the raw representation type - /// and NOT the complete information content of the ColumnType. Code should not assume that - /// a RawType uniquely identifiers a ColumnType. + /// The raw for this . Note that this is the raw representation type + /// and not the complete information content of the . Code should not assume that + /// a uniquely identifies a . For example, most practical instances of + /// and will have a of , + /// but both are very different in the types of information conveyed in that number. /// - public Type RawType { get { return _rawType; } } + public Type RawType { get; } /// - /// The DataKind corresponding to RawType, if there is one (zero otherwise). It is equivalent - /// to the result produced by DataKindExtensions.TryGetDataKind(RawType, out kind). + /// The corresponding to , if there is one (default otherwise). + /// It is equivalent to the result produced by . + /// For external code it would be preferable to operate over . 
/// - public DataKind RawKind { get { return _rawKind; } } + [BestFriend] + internal DataKind RawKind { get; } /// - /// Whether this is a primitive type. + /// Whether this is a primitive type. External code should use is . /// - public bool IsPrimitive { get { return _isPrimitive; } } + [BestFriend] + internal bool IsPrimitive { get; } /// - /// Equivalent to "this as PrimitiveType". + /// Equivalent to as . /// - public PrimitiveType AsPrimitive { get { return _isPrimitive ? (PrimitiveType)this : null; } } + [BestFriend] + internal PrimitiveType AsPrimitive => IsPrimitive ? (PrimitiveType)this : null; /// - /// Whether this type is a standard numeric type. + /// Whether this type is a standard numeric type. External code should use is . /// - public bool IsNumber { get { return _isNumber; } } + [BestFriend] + internal bool IsNumber { get; } /// - /// Whether this type is the standard text type. + /// Whether this type is the standard text type. External code should use is . /// - public bool IsText + [BestFriend] + internal bool IsText { get { @@ -105,9 +109,10 @@ public bool IsText } /// - /// Whether this type is the standard boolean type. + /// Whether this type is the standard boolean type. External code should use is . /// - public bool IsBool + [BestFriend] + internal bool IsBool { get { @@ -120,118 +125,93 @@ public bool IsBool } /// - /// Whether this type is the standard type. - /// - public bool IsTimeSpan - { - get - { - Contracts.Assert((this == TimeSpanType.Instance) == (this is TimeSpanType)); - return this is TimeSpanType; - } - } - - /// - /// Whether this type is a . 
- /// - public bool IsDateTime - { - get - { - Contracts.Assert((this == DateTimeType.Instance) == (this is DateTimeType)); - return this is DateTimeType; - } - } - - /// - /// Whether this type is a - /// - public bool IsDateTimeZone - { - get - { - Contracts.Assert((this == DateTimeOffsetType.Instance) == (this is DateTimeOffsetType)); - return this is DateTimeOffsetType; - } - } - - /// - /// Whether this type is a standard scalar type completely determined by its RawType - /// (not a KeyType or StructureType, etc). + /// Whether this type is a standard scalar type completely determined by its + /// (not a or , etc). /// - public bool IsStandardScalar - { - get { return IsNumber || IsText || IsBool || IsTimeSpan || IsDateTime || IsDateTimeZone; } - } + [BestFriend] + internal bool IsStandardScalar => IsNumber || IsText || IsBool || + (this is TimeSpanType) || (this is DateTimeType) || (this is DateTimeOffsetType); /// /// Whether this type is a key type, which implies that the order of values is not significant, /// and arithmetic is non-sensical. A key type can define a cardinality. + /// External code should use is . /// - public bool IsKey { get { return _isKey; } } + [BestFriend] + internal bool IsKey { get; } /// - /// Equivalent to "this as KeyType". + /// Equivalent to as . /// - public KeyType AsKey { get { return _isKey ? (KeyType)this : null; } } + [BestFriend] + internal KeyType AsKey => IsKey ? (KeyType)this : null; /// - /// Zero return means either it's not a key type or the cardinality is unknown. + /// Zero return means either it's not a key type or the cardinality is unknown. External code should first + /// test whether this is of type , then if so get the property + /// from that. /// - public int KeyCount { get { return KeyCountCore; } } + [BestFriend] + internal int KeyCount => KeyCountCore; /// - /// The only sub-class that should override this is KeyType! + /// The only sub-class that should override this is . 
/// - internal virtual int KeyCountCore { get { return 0; } } + private protected virtual int KeyCountCore => 0; /// - /// Whether this is a vector type. + /// Whether this is a vector type. External code should just check directly against whether this type + /// is . /// - public bool IsVector { get { return _isVector; } } + [BestFriend] + internal bool IsVector { get; } /// - /// Equivalent to "this as VectorType". + /// Equivalent to as . /// - public VectorType AsVector { get { return _isVector ? (VectorType)this : null; } } + [BestFriend] + internal VectorType AsVector => IsVector ? (VectorType)this : null; /// - /// For non-vector types, this returns the column type itself (ie, return this). + /// For non-vector types, this returns the column type itself (i.e., return this). /// - public ColumnType ItemType { get { return ItemTypeCore; } } + [BestFriend] + internal ColumnType ItemType => ItemTypeCore; /// /// Whether this is a vector type with known size. Returns false for non-vector types. - /// Equivalent to VectorSize > 0. + /// Equivalent to > 0. /// - public bool IsKnownSizeVector { get { return VectorSize > 0; } } + [BestFriend] + internal bool IsKnownSizeVector => VectorSize > 0; /// - /// Zero return means either it's not a vector or the size is unknown. Equivalent to - /// IsVector ? ValueCount : 0 and to IsKnownSizeVector ? ValueCount : 0. + /// Zero return means either it's not a vector or the size is unknown. /// - public int VectorSize { get { return VectorSizeCore; } } + [BestFriend] + internal int VectorSize => VectorSizeCore; /// /// For non-vectors, this returns one. For unknown size vectors, it returns zero. /// Equivalent to IsVector ? VectorSize : 1. /// - public int ValueCount { get { return ValueCountCore; } } + [BestFriend] + internal int ValueCount => ValueCountCore; /// /// The only sub-class that should override this is VectorType! 
/// - internal virtual ColumnType ItemTypeCore { get { return this; } } + private protected virtual ColumnType ItemTypeCore => this; /// - /// The only sub-class that should override this is VectorType! + /// The only sub-class that should override this is ! /// - internal virtual int VectorSizeCore { get { return 0; } } + private protected virtual int VectorSizeCore => 0; /// /// The only sub-class that should override this is VectorType! /// - internal virtual int ValueCountCore { get { return 1; } } + private protected virtual int ValueCountCore => 1; // IEquatable interface recommends also to override base class implementations of // Object.Equals(Object) and GetHashCode. In classes below where Equals(ColumnType other) @@ -243,7 +223,8 @@ public bool IsStandardScalar /// Equivalent to calling Equals(ColumnType) for non-vector types. For vector type, /// returns true if current and other vector types have the same size and item type. /// - public bool SameSizeAndItemType(ColumnType other) + [BestFriend] + internal bool SameSizeAndItemType(ColumnType other) { if (other == null) return false; @@ -271,7 +252,7 @@ protected StructuredType(Type rawType) Contracts.Assert(!IsPrimitive); } - internal StructuredType(Type rawType, DataKind rawKind) + private protected StructuredType(Type rawType, DataKind rawKind) : base(rawType, rawKind) { Contracts.Assert(!IsPrimitive); @@ -289,17 +270,18 @@ protected PrimitiveType(Type rawType) { Contracts.Assert(IsPrimitive); Contracts.CheckParam(!typeof(IDisposable).IsAssignableFrom(RawType), nameof(rawType), - "A PrimitiveType cannot have a disposable RawType"); + "A " + nameof(PrimitiveType) + " cannot have a disposable " + nameof(RawType)); } - internal PrimitiveType(Type rawType, DataKind rawKind) + private protected PrimitiveType(Type rawType, DataKind rawKind) : base(rawType, rawKind) { Contracts.Assert(IsPrimitive); Contracts.Assert(!typeof(IDisposable).IsAssignableFrom(RawType)); } - public static PrimitiveType 
FromKind(DataKind kind) + [BestFriend] + internal static PrimitiveType FromKind(DataKind kind) { if (kind == DataKind.TX) return TextType.Instance; @@ -344,10 +326,7 @@ public override bool Equals(ColumnType other) return false; } - public override string ToString() - { - return "Text"; - } + public override string ToString() => "Text"; } /// @@ -486,51 +465,50 @@ public static NumberType R8 } } - public static NumberType Float - { - get { return R4; } - } + public static NumberType Float => R4; - public static new NumberType FromKind(DataKind kind) + [BestFriend] + internal static new NumberType FromKind(DataKind kind) { switch (kind) { - case DataKind.I1: - return I1; - case DataKind.U1: - return U1; - case DataKind.I2: - return I2; - case DataKind.U2: - return U2; - case DataKind.I4: - return I4; - case DataKind.U4: - return U4; - case DataKind.I8: - return I8; - case DataKind.U8: - return U8; - case DataKind.R4: - return R4; - case DataKind.R8: - return R8; - case DataKind.UG: - return UG; + case DataKind.I1: + return I1; + case DataKind.U1: + return U1; + case DataKind.I2: + return I2; + case DataKind.U2: + return U2; + case DataKind.I4: + return I4; + case DataKind.U4: + return U4; + case DataKind.I8: + return I8; + case DataKind.U8: + return U8; + case DataKind.R4: + return R4; + case DataKind.R8: + return R8; + case DataKind.UG: + return UG; } Contracts.Assert(false); - throw Contracts.Except("Bad data kind in NumericType.FromKind: {0}", kind); + throw Contracts.Except($"Bad data kind in {nameof(NumberType)}.{nameof(FromKind)}: {kind}"); } - public static NumberType FromType(Type type) + [BestFriend] + internal static NumberType FromType(Type type) { DataKind kind; if (type.TryGetDataKind(out kind)) return FromKind(kind); Contracts.Assert(false); - throw Contracts.Except("Bad data kind in NumericType.FromKind: {0}", kind); + throw Contracts.Except($"Bad data kind in {nameof(NumberType)}.{nameof(FromType)}: {kind}", kind); } public override bool 
Equals(ColumnType other) @@ -541,10 +519,7 @@ public override bool Equals(ColumnType other) return false; } - public override string ToString() - { - return _name; - } + public override string ToString() => _name; } /// @@ -608,10 +583,7 @@ public override bool Equals(ColumnType other) return false; } - public override string ToString() - { - return "DateTime"; - } + public override string ToString() => "DateTime"; } public sealed class DateTimeOffsetType : PrimitiveType @@ -640,10 +612,7 @@ public override bool Equals(ColumnType other) return false; } - public override string ToString() - { - return "DateTimeZone"; - } + public override string ToString() => "DateTimeZone"; } /// @@ -675,10 +644,7 @@ public override bool Equals(ColumnType other) return false; } - public override string ToString() - { - return "TimeSpan"; - } + public override string ToString() => "TimeSpan"; } /// @@ -699,26 +665,45 @@ public override string ToString() /// public sealed class KeyType : PrimitiveType { - private readonly bool _contiguous; - private readonly ulong _min; - // _count is only valid if _contiguous is true. Zero means unknown. 
- private readonly int _count; - - public KeyType(DataKind kind, ulong min, int count, bool contiguous = true) - : base(ToRawType(kind), kind) + private KeyType(Type type, DataKind kind, ulong min, int count, bool contiguous) + : base(type, kind) { + Contracts.AssertValue(type); + Contracts.Assert(kind.ToType() == type); + Contracts.CheckParam(min >= 0, nameof(min)); - Contracts.CheckParam(count >= 0, nameof(count), "count for key type must be non-negative"); + Contracts.CheckParam(count >= 0, nameof(count), "Must be non-negative."); Contracts.CheckParam((ulong)count <= ulong.MaxValue - min, nameof(count)); Contracts.CheckParam((ulong)count <= kind.ToMaxInt(), nameof(count)); - Contracts.CheckParam(contiguous || count == 0, nameof(count), "count must be 0 for non-contiguous"); + Contracts.CheckParam(contiguous || count == 0, nameof(count), "Must be 0 for non-contiguous"); - _contiguous = contiguous; - _min = min; - _count = count; + Contiguous = contiguous; + Min = min; + Count = count; Contracts.Assert(IsKey); } + public KeyType(Type type, ulong min, int count, bool contiguous = true) + : this(type, CheckRefRawType(type), min, count, contiguous) + { + } + + [BestFriend] + internal KeyType(DataKind kind, ulong min, int count, bool contiguous = true) + : this(ToRawType(kind), kind, min, count, contiguous) + { + } + + private static DataKind CheckRefRawType(Type type) + { + Contracts.CheckValue(type, nameof(type)); + Contracts.CheckParam(IsValidDataType(type), nameof(type)); + var result = type.TryGetDataKind(out var kind); + Contracts.Assert(result); + return kind; + + } + private static Type ToRawType(DataKind kind) { Contracts.CheckParam(IsValidDataKind(kind), nameof(kind)); @@ -726,31 +711,43 @@ private static Type ToRawType(DataKind kind) } /// - /// Returns true iff the given DataKind is valid for a KeyType. The valid ones are - /// U1, U2, U4, and U8, that is, the unsigned integer kinds. + /// Returns true iff the given DataKind is valid for a . 
The valid ones are + /// , , , and , + /// that is, the unsigned integer kinds. /// - public static bool IsValidDataKind(DataKind kind) + [BestFriend] + internal static bool IsValidDataKind(DataKind kind) { switch (kind) { - case DataKind.U1: - case DataKind.U2: - case DataKind.U4: - case DataKind.U8: - return true; - default: - return false; + case DataKind.U1: + case DataKind.U2: + case DataKind.U4: + case DataKind.U8: + return true; + default: + return false; } } - internal override int KeyCountCore { get { return _count; } } + /// + /// Returns true iff the given type is valid for a . The valid ones are + /// , , , and , that is, the unsigned integer types. + /// + public static bool IsValidDataType(Type type) + { + Contracts.CheckValue(type, nameof(type)); + return type == typeof(byte) || type == typeof(ushort) || type == typeof(uint) || type == typeof(ulong); + } + + private protected override int KeyCountCore => Count; /// /// This is the Min of the key type for display purposes and conversion to/from text. The values /// actually stored always start at 1 (for the smallest legal value), with zero being reserved /// for "not there"/"none". Typical Min values are 0 or 1, but can be any value >= 0. /// - public ulong Min { get { return _min; } } + public ulong Min { get; } /// /// If this key type has contiguous values and a known cardinality, Count is that cardinality. @@ -760,9 +757,9 @@ public static bool IsValidDataKind(DataKind kind) /// representation. Note that an id of 0 is used to represent the notion "none", which is /// typically mapped to a vector of all zeros (of length Count). 
/// - public int Count { get { return _count; } } + public int Count { get; } - public bool Contiguous { get { return _contiguous; } } + public bool Contiguous { get; } public override bool Equals(ColumnType other) { @@ -775,11 +772,11 @@ public override bool Equals(ColumnType other) if (RawKind != tmp.RawKind) return false; Contracts.Assert(RawType == tmp.RawType); - if (_contiguous != tmp._contiguous) + if (Contiguous != tmp.Contiguous) return false; - if (_min != tmp._min) + if (Min != tmp.Min) return false; - if (_count != tmp._count) + if (Count != tmp.Count) return false; return true; } @@ -791,17 +788,17 @@ public override bool Equals(object other) public override int GetHashCode() { - return Hashing.CombinedHash(RawKind.GetHashCode(), _contiguous, _min, _count); + return Hashing.CombinedHash(RawKind.GetHashCode(), Contiguous, Min, Count); } public override string ToString() { - if (_count > 0) - return string.Format("Key<{0}, {1}-{2}>", RawKind.GetString(), _min, _min + (ulong)_count - 1); - if (_contiguous) - return string.Format("Key<{0}, {1}-*>", RawKind.GetString(), _min); + if (Count > 0) + return string.Format("Key<{0}, {1}-{2}>", RawKind.GetString(), Min, Min + (ulong)Count - 1); + if (Contiguous) + return string.Format("Key<{0}, {1}-*>", RawKind.GetString(), Min); // This is the non-contiguous case - simply show the Min. - return string.Format("Key<{0}, Min:{1}>", RawKind.GetString(), _min); + return string.Format("Key<{0}, Min:{1}>", RawKind.GetString(), Min); } } @@ -810,74 +807,75 @@ public override string ToString() /// public sealed class VectorType : StructuredType { - private readonly PrimitiveType _itemType; - private readonly int _size; - - // The _sizes are the cumulative products of the _dims. These may be null, meaning that - // the information is naturally one dimensional. - private readonly int[] _sizes; - private readonly int[] _dims; + /// + /// The dimensions. This will always have at least one item. 
All values will be non-negative. + /// As with , a zero value indicates that the vector type is considered to have + /// unknown length along that dimension. + /// + public ImmutableArray Dimensions { get; } + /// + /// Constructs a new single-dimensional vector type. + /// + /// The type of the items contained in the vector. + /// The size of the single dimension. public VectorType(PrimitiveType itemType, int size = 0) : base(GetRawType(itemType), 0) { Contracts.CheckParam(size >= 0, nameof(size)); - _itemType = itemType; - _size = size; + ItemType = itemType; + Size = size; + Dimensions = ImmutableArray.Create(Size); } - public VectorType(PrimitiveType itemType, params int[] dims) - : base(GetRawType(itemType), default(DataKind)) + /// + /// Constructs a potentially multi-dimensional vector type. + /// + /// The type of the items contained in the vector. + /// The dimensions. Note that, like , must be non-empty, with all + /// non-negative values. Also, because is the product of , the result of + /// multiplying all these values together must not overflow . + public VectorType(PrimitiveType itemType, params int[] dimensions) + : base(GetRawType(itemType), default) { - Contracts.CheckParam(Utils.Size(dims) > 0, nameof(dims)); - Contracts.CheckParam(dims.All(d => d >= 0), nameof(dims)); - - _itemType = itemType; + Contracts.CheckParam(Utils.Size(dimensions) > 0, nameof(dimensions)); + Contracts.CheckParam(dimensions.All(d => d >= 0), nameof(dimensions)); - if (dims.Length == 1) - _size = dims[0]; - else - { - _dims = new int[dims.Length]; - Array.Copy(dims, _dims, _dims.Length); - _size = ComputeSizes(_dims, out _sizes); - } + ItemType = itemType; + Dimensions = dimensions.ToImmutableArray(); + Size = ComputeSize(Dimensions); } /// - /// Creates a VectorType whose dimensionality information is the given template's information. + /// Creates a whose dimensionality information is the given 's information. 
/// - public VectorType(PrimitiveType itemType, VectorType template) - : base(GetRawType(itemType), default(DataKind)) + [BestFriend] + internal VectorType(PrimitiveType itemType, VectorType template) + : base(GetRawType(itemType), default) { Contracts.CheckValue(template, nameof(template)); - _itemType = itemType; - _size = template._size; - _sizes = template._sizes; - _dims = template._dims; + ItemType = itemType; + Dimensions = template.Dimensions; + Size = template.Size; } /// - /// Creates a VectorType whose dimensionality information is the given template's information - /// concatenated with the specified dims. + /// Creates a whose dimensionality information is the given 's information, + /// concatenated with the specified . /// - public VectorType(PrimitiveType itemType, VectorType template, params int[] dims) - : base(GetRawType(itemType), default(DataKind)) + [BestFriend] + internal VectorType(PrimitiveType itemType, VectorType template, params int[] dims) + : base(GetRawType(itemType), default) { Contracts.CheckValue(template, nameof(template)); + Contracts.CheckParam(Utils.Size(dims) > 0, nameof(dims)); + Contracts.CheckParam(dims.All(d => d >= 0), nameof(dims)); - _itemType = itemType; - - if (template._dims == null) - _dims = Utils.Concat(new int[] { template._size }, dims); - else - { - Contracts.Assert(template._dims.Length >= 2); - _dims = Utils.Concat(template._dims, dims); - } - _size = ComputeSizes(_dims, out _sizes); + ItemType = itemType; + Dimensions = template.Dimensions.AddRange(dims); + Size = ComputeSize(Dimensions); } private static Type GetRawType(PrimitiveType itemType) @@ -886,60 +884,48 @@ private static Type GetRawType(PrimitiveType itemType) return typeof(VBuffer<>).MakeGenericType(itemType.RawType); } - private static int ComputeSizes(int[] dims, out int[] sizes) + private static int ComputeSize(ImmutableArray dims) { - sizes = new int[dims.Length]; int size = 1; - for (int i = dims.Length; --i >= 0; ) - size = sizes[i] = 
checked(size * dims[i]); + for (int i = 0; i < dims.Length; ++i) + size = checked(size * dims[i]); return size; } - public int DimCount { get { return _dims != null ? _dims.Length : _size > 0 ? 1 : 0; } } - - public int GetDim(int idim) - { - if (_dims == null) - { - // That that if _size is zero, DimCount is zero, so this method is illegal - // to call. That case is caught by Check(_size > 0). - Contracts.Check(_size > 0); - Contracts.Assert(DimCount == 1); - Contracts.CheckParam(idim == 0, nameof(idim)); - return _size; - } - - Contracts.CheckParam(0 <= idim && idim < _dims.Length, nameof(idim)); - return _dims[idim]; - } + /// + /// The type of the items stored as values in vectors of this type. + /// + public new PrimitiveType ItemType { get; } - public new PrimitiveType ItemType { get { return _itemType; } } + /// + /// The size of the vector. A value of zero means it is a vector whose size is unknown. + /// A vector whose size is known should correspond to values that always have the same , + /// whereas one whose size is unknown may have values whose varies from record to record. + /// Note that this is always the product of the elements in . 
+ /// + public int Size { get; } - internal override ColumnType ItemTypeCore { get { return _itemType; } } + private protected override ColumnType ItemTypeCore => ItemType; - internal override int VectorSizeCore { get { return _size; } } + private protected override int VectorSizeCore => Size; - internal override int ValueCountCore { get { return _size; } } + private protected override int ValueCountCore => Size; public override bool Equals(ColumnType other) { if (other == this) return true; - var tmp = other.AsVector; - if (tmp == null) + if (!(other is VectorType tmp)) return false; - if (!_itemType.Equals(tmp._itemType)) + if (!ItemType.Equals(tmp.ItemType)) return false; - if (_size != tmp._size) + if (Size != tmp.Size) return false; - int count = Utils.Size(_dims); - if (count != Utils.Size(tmp._dims)) + if (Dimensions.Length != tmp.Dimensions.Length) return false; - if (count == 0) - return true; - for (int i = 0; i < count; i++) + for (int i = 0; i < Dimensions.Length; i++) { - if (_dims[i] != tmp._dims[i]) + if (Dimensions[i] != tmp.Dimensions[i]) return false; } return true; @@ -952,47 +938,26 @@ public override bool Equals(object other) public override int GetHashCode() { - int hash = Hashing.CombinedHash(_itemType.GetHashCode(), _size); - int count = Utils.Size(_dims); - hash = Hashing.CombineHash(hash, count.GetHashCode()); - for (int i = 0; i < count; i++) - hash = Hashing.CombineHash(hash, _dims[i].GetHashCode()); + int hash = Hashing.CombinedHash(ItemType.GetHashCode(), Size); + hash = Hashing.CombineHash(hash, Dimensions.Length); + for (int i = 0; i < Dimensions.Length; i++) + hash = Hashing.CombineHash(hash, Dimensions[i].GetHashCode()); return hash; } - /// - /// Returns true if current has the same item type of other, and the size - /// of other is unknown or the current size is equal to the size of other. 
- /// - public bool IsSubtypeOf(VectorType other) - { - if (other == this) - return true; - if (other == null) - return false; - - // REVIEW: Perhaps we should allow the case when _itemType is - // a sub-type of other._itemType (in particular for key types) - if (!_itemType.Equals(other._itemType)) - return false; - if (other._size == 0 || _size == other._size) - return true; - return false; - } - public override string ToString() { var sb = new StringBuilder(); - sb.Append("Vec<").Append(_itemType); + sb.Append("Vec<").Append(ItemType); - if (_dims == null) + if (Dimensions.Length == 1) { - if (_size > 0) - sb.Append(", ").Append(_size); + if (Size > 0) + sb.Append(", ").Append(Size); } else { - foreach (var dim in _dims) + foreach (var dim in Dimensions) { sb.Append(", "); if (dim > 0) diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs index 7ed3aecd9e..ee70bc732a 100644 --- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs +++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs @@ -14,7 +14,7 @@ namespace Microsoft.ML.Runtime.Data { /// - /// Utilities for implementing and using the metadata API of ISchema. + /// Utilities for implementing and using the metadata API of . 
/// public static class MetadataUtils { @@ -433,7 +433,7 @@ public static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex, bool isValid = false; categoricalFeatures = null; - if (!schema.GetColumnType(colIndex).IsKnownSizeVector) + if (!(schema.GetColumnType(colIndex) is VectorType vecType && vecType.Size > 0)) return isValid; var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex); @@ -442,7 +442,7 @@ public static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex, VBuffer catIndices = default(VBuffer); schema.GetMetadata(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex, ref catIndices); VBufferUtils.Densify(ref catIndices); - int columnSlotsCount = schema.GetColumnType(colIndex).AsVector.VectorSizeCore; + int columnSlotsCount = vecType.Size; if (catIndices.Length > 0 && catIndices.Length % 2 == 0 && catIndices.Length <= columnSlotsCount * 2) { int previousEndIndex = -1; diff --git a/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj b/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj index c0f8558822..dea0558bb8 100644 --- a/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj +++ b/src/Microsoft.ML.Core/Microsoft.ML.Core.csproj @@ -12,6 +12,7 @@ + diff --git a/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs index 453d5d5d44..0ef400e485 100644 --- a/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.Core/Properties/AssemblyInfo.cs @@ -3,5 +3,39 @@ // See the LICENSE file in the project root for more information. 
using System.Runtime.CompilerServices; +using Microsoft.ML; -[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TestFramework, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TestFramework" + PublicKey.TestValue)] + +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Legacy" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Maml" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.PipelineInference" + PublicKey.Value)] + +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Data" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Api" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Ensemble" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.FastTree" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.HalLearners" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.KMeansClustering" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.LightGBM" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Onnx" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxTransform" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.PCA" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Recommender" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Runtime.ImageAnalytics" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Scoring" + 
PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StandardLearners" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TensorFlow" + PublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Transforms" + PublicKey.Value)] + +[assembly: WantsToBeBestFriends] + +namespace Microsoft.ML +{ + [BestFriend] + internal static class PublicKey + { + public const string Value = ", PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb"; + public const string TestValue = ", PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4"; + } +} diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs index 05a776cc7b..2362741bc8 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs @@ -827,13 +827,13 @@ public VBufferCodec(CodecFactory factory, VectorType type, IValueCodec innerC public int WriteParameterization(Stream stream) { int total = _factory.WriteCodec(stream, _innerCodec); - int count = _type.DimCount; + int count = _type.Dimensions.Length; total += sizeof(int) * (1 + count); using (BinaryWriter writer = _factory.OpenBinaryWriter(stream)) { writer.Write(count); for (int i = 0; i < count; i++) - writer.Write(_type.GetDim(i)); + writer.Write(_type.Dimensions[i]); } return total; } @@ -1163,7 +1163,13 @@ private bool 
GetVBufferCodec(Stream definitionStream, out IValueCodec codec) type = new VectorType(itemType, dims); } else + { + // In prior times, in the case where the VectorType was of single rank, *and* of unknown length, + // then the vector type would be considered to have a dimension count of 0, for some reason. + // This can no longer occur, but in the case where we read an older file we have to account for + // the fact that nothing may have been written. type = new VectorType(itemType); + } } // Next create the vbuffer codec. Type codecType = typeof(VBufferCodec<>).MakeGenericType(itemType.RawType); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs index 8e83f01ac1..db84057b79 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs @@ -97,17 +97,17 @@ protected ValueWriterBase(PrimitiveType type, int source, char sep) ValueMapper, StringBuilder> c = MapText; Conv = (ValueMapper)(Delegate)c; } - else if (type.IsTimeSpan) + else if (type is TimeSpanType) { ValueMapper c = MapTimeSpan; Conv = (ValueMapper)(Delegate)c; } - else if (type.IsDateTime) + else if (type is DateTimeType) { ValueMapper c = MapDateTime; Conv = (ValueMapper)(Delegate)c; } - else if (type.IsDateTimeZone) + else if (type is DateTimeOffsetType) { ValueMapper c = MapDateTimeZone; Conv = (ValueMapper)(Delegate)c; diff --git a/src/Microsoft.ML.Data/Transforms/NAFilter.cs b/src/Microsoft.ML.Data/Transforms/NAFilter.cs index 05ffea70b5..cd49e30b78 100644 --- a/src/Microsoft.ML.Data/Transforms/NAFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/NAFilter.cs @@ -113,7 +113,7 @@ public NAFilter(IHostEnvironment env, Arguments args, IDataView input) var type = schema.GetColumnType(index); if (!TestType(type)) - throw Host.ExceptUserArg(nameof(args.Column), "Column '{0}' does not have compatible numeric type", src); + throw Host.ExceptUserArg(nameof(args.Column), $"Column 
'{src}' has type {type} which does not support missing values, so we cannot filter on them", src); _infos[i] = new ColInfo(index, type); _srcIndexToInfoIndex.Add(index, i); @@ -149,7 +149,7 @@ public NAFilter(IHost host, ModelLoadContext ctx, IDataView input) var type = schema.GetColumnType(index); if (!TestType(type)) - throw Host.Except("Column '{0}' does not have compatible numeric type", src); + throw Host.Except($"Column '{src}' has type {type} which does not support missing values, so we cannot filter on them", src); _infos[i] = new ColInfo(index, type); _srcIndexToInfoIndex.Add(index, i); @@ -187,32 +187,12 @@ private static bool TestType(ColumnType type) { Contracts.AssertValue(type); - var itemType = type.ItemType; - if (itemType.IsNumber) - { - switch (itemType.RawKind) - { - case DataKind.I1: - case DataKind.I2: - case DataKind.I4: - case DataKind.I8: - case DataKind.R4: - case DataKind.R8: - return true; - } - return false; - } - if (itemType.IsText) - return true; - if (itemType.IsBool) - return true; - if (itemType.IsKey) - return true; - if (itemType.IsTimeSpan) + var itemType = (type as VectorType)?.ItemType ?? 
type; + if (itemType == NumberType.R4) return true; - if (itemType.IsDateTime) + if (itemType == NumberType.R8) return true; - if (itemType.IsDateTimeZone) + if (itemType is KeyType) return true; return false; } @@ -287,15 +267,15 @@ public static Value Create(RowCursor cursor, ColInfo info) Contracts.AssertValue(info); MethodInfo meth; - if (!info.Type.IsVector) + if (info.Type is VectorType vecType) { - Func> d = CreateOne; - meth = d.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(info.Type.RawType); + Func> d = CreateVec; + meth = d.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(vecType.ItemType.RawType); } else { - Func> d = CreateVec; - meth = d.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(info.Type.ItemType.RawType); + Func> d = CreateOne; + meth = d.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(info.Type.RawType); } return (Value)meth.Invoke(null, new object[] { cursor, info }); } @@ -304,7 +284,7 @@ private static ValueOne CreateOne(RowCursor cursor, ColInfo info) { Contracts.AssertValue(cursor); Contracts.AssertValue(info); - Contracts.Assert(!info.Type.IsVector); + Contracts.Assert(!(info.Type is VectorType)); Contracts.Assert(info.Type.RawType == typeof(T)); var getSrc = cursor.Input.GetGetter(info.Index); @@ -316,8 +296,8 @@ private static ValueVec CreateVec(RowCursor cursor, ColInfo info) { Contracts.AssertValue(cursor); Contracts.AssertValue(info); - Contracts.Assert(info.Type.IsVector); - Contracts.Assert(info.Type.ItemType.RawType == typeof(T)); + Contracts.Assert(info.Type is VectorType); + Contracts.Assert(info.Type.RawType == typeof(VBuffer)); var getSrc = cursor.Input.GetGetter>(info.Index); var hasBad = Runtime.Data.Conversion.Conversions.Instance.GetHasMissingPredicate((VectorType)info.Type); diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs index 1bb0db3b4c..d19cf6f11b 100644 --- 
a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs @@ -442,15 +442,16 @@ protected override Delegate MakeGetter(IRow input, int iinfo, out Action dispose private ValueGetter> GetGetterCore(IRow input, int iinfo, out Action disposer) { var type = _types[iinfo]; - Contracts.Assert(type.DimCount == 3); + var dims = type.Dimensions; + Contracts.Assert(dims.Length == 3); var ex = _parent._columns[iinfo]; - int planes = ex.Interleave ? type.GetDim(2) : type.GetDim(0); - int height = ex.Interleave ? type.GetDim(0) : type.GetDim(1); - int width = ex.Interleave ? type.GetDim(1) : type.GetDim(2); + int planes = ex.Interleave ? dims[2] : dims[0]; + int height = ex.Interleave ? dims[0] : dims[1]; + int width = ex.Interleave ? dims[1] : dims[2]; - int size = type.ValueCount; + int size = type.Size; Contracts.Assert(size > 0); Contracts.Assert(size == planes * height * width); int cpix = height * width; diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs index c325f87a4d..c0bcb5318d 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs @@ -922,8 +922,8 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl else { ArrayDataViewBuilder arrDv = new ArrayDataViewBuilder(host); - arrDv.AddColumn(DefaultColumnNames.Features, PrimitiveType.FromKind(DataKind.R4), clusters); - arrDv.AddColumn(DefaultColumnNames.Weight, PrimitiveType.FromKind(DataKind.R4), totalWeights); + arrDv.AddColumn(DefaultColumnNames.Features, NumberType.R4, clusters); + arrDv.AddColumn(DefaultColumnNames.Weight, NumberType.R4, totalWeights); var subDataViewCursorFactory = new FeatureFloatVectorCursor.Factory( new RoleMappedData(arrDv.GetDataView(), null, DefaultColumnNames.Features, weight: DefaultColumnNames.Weight), CursOpt.Weight | 
CursOpt.Features); long discard1; diff --git a/src/Microsoft.ML.Onnx/OnnxUtils.cs b/src/Microsoft.ML.Onnx/OnnxUtils.cs index 83623ab0ba..e7c06ee53a 100644 --- a/src/Microsoft.ML.Onnx/OnnxUtils.cs +++ b/src/Microsoft.ML.Onnx/OnnxUtils.cs @@ -368,8 +368,8 @@ public static ModelArgs GetModelArgs(ColumnType type, string colName, else if (type.ValueCount > 1) { var vec = type.AsVector; - for (int i = 0; i < vec.DimCount; i++) - dimsLocal.Add(vec.GetDim(i)); + for (int i = 0; i < vec.Dimensions.Length; i++) + dimsLocal.Add(vec.Dimensions[i]); } } //batch size. diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index bacc9945b7..fb981a1bf2 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -828,20 +828,23 @@ public Mapper(IHostEnvironment env, TensorFlowTransform parent, ISchema inputSch throw _host.Except($"Column {_parent.Inputs[i]} doesn't exist"); var type = inputSchema.GetColumnType(_inputColIndices[i]); - if (type.IsVector && type.VectorSize == 0) - throw _host.Except($"Variable length input columns not supported"); + if (type is VectorType vecType && vecType.Size == 0) + throw _host.Except("Variable length input columns not supported"); - _isInputVector[i] = type.IsVector; + _isInputVector[i] = type is VectorType; + if (!_isInputVector[i]) // Temporary pending fix of issue #1542. In its current state, the below code would fail anyway with a naked exception if this check was not here. 
+ throw _host.Except("Non-vector columns not supported"); + vecType = (VectorType)type; var expectedType = TensorFlowUtils.Tf2MlNetType(_parent.TFInputTypes[i]); if (type.ItemType != expectedType) throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], expectedType.ToString(), type.ToString()); var originalShape = _parent.TFInputShapes[i]; var shape = originalShape.ToIntArray(); - var colTypeDims = Enumerable.Range(0, type.AsVector.DimCount + 1).Select(d => d == 0 ? 1 : (long)type.AsVector.GetDim(d - 1)).ToArray(); + var colTypeDims = vecType.Dimensions.Prepend(1).Select(dim => (long)dim).ToArray(); if (shape == null) _fullySpecifiedShapes[i] = new TFShape(colTypeDims); - else if (type.AsVector.DimCount == 1) + else if (vecType.Dimensions.Length == 1) { // If the column is one dimension we make sure that the total size of the TF shape matches. // Compute the total size of the known dimensions of the shape. diff --git a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs index b0c9df8406..6e59b928bf 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs @@ -153,8 +153,7 @@ internal static string TestType(ColumnType type) { // Item type must have an NA value that exists and is not equal to its default value. 
Func func = TestType; - var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.ItemType.RawType); - return (string)meth.Invoke(null, new object[] { type.ItemType }); + return Utils.MarshalInvoke(func, type.ItemType.RawType, type.ItemType); } private static string TestType(ColumnType type) @@ -339,7 +338,7 @@ private void GetReplacementValues(IDataView input, ColumnInfo[] columns, out obj case ReplacementKind.Mean: case ReplacementKind.Minimum: case ReplacementKind.Maximum: - if (!type.ItemType.IsNumber && !type.ItemType.IsTimeSpan && !type.ItemType.IsDateTime) + if (!type.ItemType.IsNumber) throw Host.Except("Cannot perform mean imputations on non-numeric '{0}'", type.ItemType); imputationModes[iinfo] = kind; Utils.Add(ref columnsToImpute, iinfo); diff --git a/test/Microsoft.ML.Benchmarks/HashBench.cs b/test/Microsoft.ML.Benchmarks/HashBench.cs index 53e5789a13..74a056ea77 100644 --- a/test/Microsoft.ML.Benchmarks/HashBench.cs +++ b/test/Microsoft.ML.Benchmarks/HashBench.cs @@ -46,7 +46,7 @@ private void InitMap(T val, ColumnType type, int hashBits = 20) var mapper = xf.GetRowToRowMapper(inRow.Schema); mapper.Schema.TryGetColumnIndex("Bar", out int outCol); var outRow = mapper.GetRow(inRow, c => c == outCol, out var _); - if (type.IsVector) + if (type is VectorType) _vecGetter = outRow.GetGetter>(outCol); else _getter = outRow.GetGetter(outCol); @@ -123,7 +123,7 @@ public void HashScalarDouble() [GlobalSetup(Target = nameof(HashScalarKey))] public void SetupHashScalarKey() { - InitMap(6u, new KeyType(DataKind.U4, 0, 100)); + InitMap(6u, new KeyType(typeof(uint), 0, 100)); } [Benchmark] @@ -174,7 +174,7 @@ public void HashVectorDouble() [GlobalSetup(Target = nameof(HashVectorKey))] public void SetupHashVectorKey() { - InitDenseVecMap(new[] { 1u, 2u, 0u, 4u, 5u }, new KeyType(DataKind.U4, 0, 100)); + InitDenseVecMap(new[] { 1u, 2u, 0u, 4u, 5u }, new KeyType(typeof(uint), 0, 100)); } [Benchmark] diff --git 
a/test/Microsoft.ML.Core.Tests/UnitTests/ColumnTypes.cs b/test/Microsoft.ML.Core.Tests/UnitTests/ColumnTypes.cs index f3cfa8c270..ac0bd8c187 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/ColumnTypes.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/ColumnTypes.cs @@ -17,12 +17,15 @@ public void TestEqualAndGetHashCode() { var dict = new Dictionary(); // add PrimitiveTypes, KeyType & corresponding VectorTypes - PrimitiveType tmp; VectorType tmp1, tmp2; - foreach (var kind in (DataKind[])Enum.GetValues(typeof(DataKind))) + var types = new PrimitiveType[] { NumberType.I1, NumberType.I2, NumberType.I4, NumberType.I8, + NumberType.U1, NumberType.U2, NumberType.U4, NumberType.U8, NumberType.UG, + TextType.Instance, BoolType.Instance, DateTimeType.Instance, DateTimeOffsetType.Instance, TimeSpanType.Instance }; + + foreach (var type in types) { - tmp = PrimitiveType.FromKind(kind); - if(dict.ContainsKey(tmp) && dict[tmp] != tmp.ToString()) + var tmp = type; + if (dict.ContainsKey(tmp) && dict[tmp] != tmp.ToString()) Assert.True(false, dict[tmp] + " and " + tmp.ToString() + " are duplicates."); dict[tmp] = tmp.ToString(); for (int size = 0; size < 5; size++) @@ -41,13 +44,14 @@ public void TestEqualAndGetHashCode() } // KeyType & Vector - if (!KeyType.IsValidDataKind(kind)) + var rawType = tmp.RawType; + if (!KeyType.IsValidDataType(rawType)) continue; for (ulong min = 0; min < 5; min++) { for (var count = 0; count < 5; count++) { - tmp = new KeyType(kind, min, count); + tmp = new KeyType(rawType, min, count); if (dict.ContainsKey(tmp) && dict[tmp] != tmp.ToString()) Assert.True(false, dict[tmp] + " and " + tmp.ToString() + " are duplicates."); dict[tmp] = tmp.ToString(); @@ -66,7 +70,7 @@ public void TestEqualAndGetHashCode() } } } - tmp = new KeyType(kind, min, 0, false); + tmp = new KeyType(rawType, min, 0, false); if (dict.ContainsKey(tmp) && dict[tmp] != tmp.ToString()) Assert.True(false, dict[tmp] + " and " + tmp.ToString() + " are duplicates."); dict[tmp] = 
tmp.ToString(); diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs b/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs index 34eb063b5b..b166007f0a 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs @@ -29,7 +29,9 @@ protected bool EqualTypes(ColumnType type1, ColumnType type2, bool exactTypes) Contracts.AssertValue(type1); Contracts.AssertValue(type2); - return exactTypes ? type1.Equals(type2) : type1.SameSizeAndItemType(type2); + if (type1.Equals(type2)) + return true; + return !exactTypes && type1 is VectorType vt1 && type2 is VectorType vt2 && vt1.ItemType.Equals(vt2.ItemType) && vt1.Size == vt2.Size; } protected Func GetIdComparer(IRow r1, IRow r2, out ValueGetter idGetter) @@ -148,88 +150,93 @@ protected bool CompareVec(in VBuffer v1, in VBuffer v2, int size, Func< } protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType type, bool exactDoubles) { - if (!type.IsVector) + if (type is VectorType vecType) { - switch (type.RawKind) + int size = vecType.Size; + Contracts.Assert(size >= 0); + var result = vecType.ItemType.RawType.TryGetDataKind(out var kind); + Contracts.Assert(result); + + switch (kind) { case DataKind.I1: - return GetComparerOne(r1, r2, col, (x, y) => x == y); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U1: - return GetComparerOne(r1, r2, col, (x, y) => x == y); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I2: - return GetComparerOne(r1, r2, col, (x, y) => x == y); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U2: - return GetComparerOne(r1, r2, col, (x, y) => x == y); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I4: - return GetComparerOne(r1, r2, col, (x, y) => x == y); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U4: - return GetComparerOne(r1, r2, col, (x, 
y) => x == y); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I8: - return GetComparerOne(r1, r2, col, (x, y) => x == y); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U8: - return GetComparerOne(r1, r2, col, (x, y) => x == y); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.R4: - return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); + return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); case DataKind.R8: if (exactDoubles) - return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); + return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); else - return GetComparerOne(r1, r2, col, EqualWithEps); + return GetComparerVec(r1, r2, col, size, EqualWithEps); case DataKind.Text: - return GetComparerOne>(r1, r2, col, (a, b) => a.Span.SequenceEqual(b.Span)); + return GetComparerVec>(r1, r2, col, size, (a, b) => a.Span.SequenceEqual(b.Span)); case DataKind.Bool: - return GetComparerOne(r1, r2, col, (x, y) => x == y); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.TimeSpan: - return GetComparerOne(r1, r2, col, (x, y) => x.Ticks == y.Ticks); + return GetComparerVec(r1, r2, col, size, (x, y) => x.Ticks == y.Ticks); case DataKind.DT: - return GetComparerOne(r1, r2, col, (x, y) => x.Ticks == y.Ticks); + return GetComparerVec(r1, r2, col, size, (x, y) => x.Ticks == y.Ticks); case DataKind.DZ: - return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); + return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); case DataKind.UG: - return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); + return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); } } else { - int size = type.VectorSize; - Contracts.Assert(size >= 0); - switch (type.ItemType.RawKind) + var result = 
type.RawType.TryGetDataKind(out var kind); + Contracts.Assert(result); + switch (kind) { case DataKind.I1: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U1: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I2: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U2: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I4: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U4: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I8: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U8: - return GetComparerVec(r1, r2, col, size, (x, y) => x == y); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.R4: - return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); + return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); case DataKind.R8: if (exactDoubles) - return GetComparerVec(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); + return GetComparerOne(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y)); else - return GetComparerVec(r1, r2, col, size, EqualWithEps); + return GetComparerOne(r1, r2, col, EqualWithEps); case DataKind.Text: - return GetComparerVec>(r1, r2, col, size, (a,b) => a.Span.SequenceEqual(b.Span)); + return GetComparerOne>(r1, r2, col, (a, b) => a.Span.SequenceEqual(b.Span)); case DataKind.Bool: - return GetComparerVec(r1, r2, col, size, (x, y) => x 
== y); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.TimeSpan: - return GetComparerVec(r1, r2, col, size, (x, y) => x.Ticks == y.Ticks); + return GetComparerOne(r1, r2, col, (x, y) => x.Ticks == y.Ticks); case DataKind.DT: - return GetComparerVec(r1, r2, col, size, (x, y) => x.Ticks == y.Ticks); + return GetComparerOne(r1, r2, col, (x, y) => x.Ticks == y.Ticks); case DataKind.DZ: - return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); + return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); case DataKind.UG: - return GetComparerVec(r1, r2, col, size, (x, y) => x.Equals(y)); + return GetComparerOne(r1, r2, col, (x, y) => x.Equals(y)); } } @@ -318,7 +325,7 @@ protected bool CheckSameValues(IRowCursor curs1, IRowCursor curs2, bool exactTyp var type2 = curs2.Schema.GetColumnType(col); if (!EqualTypes(type1, type2, exactTypes)) { - Fail("Different types"); + Fail($"Different types {type1} and {type2}"); return Failed(); } comps[col] = GetColumnComparer(curs1, curs2, col, type1, exactDoubles); diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs index 3f1408efd9..385e6cd527 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs @@ -21,10 +21,10 @@ public DataTypesTest(ITestOutputHelper helper) [Fact] public void R4ToSBtoR4() { - var r4ToSB = Conversions.Instance.GetStringConversion(NumberType.FromKind(DataKind.R4)); + var r4ToSB = Conversions.Instance.GetStringConversion(NumberType.R4); var txToR4 = Conversions.Instance.GetStandardConversion< ReadOnlyMemory, float>( - TextType.Instance, NumberType.FromKind(DataKind.R4), out bool identity2); + TextType.Instance, NumberType.R4, out bool identity2); Assert.NotNull(r4ToSB); Assert.NotNull(txToR4); @@ -45,10 +45,10 @@ public void R4ToSBtoR4() [Fact] public void R8ToSBtoR8() { - var r8ToSB = 
Conversions.Instance.GetStringConversion(NumberType.FromKind(DataKind.R8)); + var r8ToSB = Conversions.Instance.GetStringConversion(NumberType.R8); var txToR8 = Conversions.Instance.GetStandardConversion, double>( - TextType.Instance, NumberType.FromKind(DataKind.R8), out bool identity2); + TextType.Instance, NumberType.R8, out bool identity2); Assert.NotNull(r8ToSB); Assert.NotNull(txToR8); @@ -69,7 +69,7 @@ public void R8ToSBtoR8() [Fact] public void TXToSByte() { - var mapper = GetMapper, sbyte>(); + var mapper = GetMapper, sbyte>(NumberType.I1); Assert.NotNull(mapper); @@ -109,7 +109,7 @@ public void TXToSByte() [Fact] public void TXToShort() { - var mapper = GetMapper, short>(); + var mapper = GetMapper, short>(NumberType.I2); Assert.NotNull(mapper); @@ -149,7 +149,7 @@ public void TXToShort() [Fact] public void TXToInt() { - var mapper = GetMapper, int>(); + var mapper = GetMapper, int>(NumberType.I4); Assert.NotNull(mapper); @@ -189,7 +189,7 @@ public void TXToInt() [Fact] public void TXToLong() { - var mapper = GetMapper, long>(); + var mapper = GetMapper, long>(NumberType.I8); Assert.NotNull(mapper); @@ -226,14 +226,12 @@ public void TXToLong() Assert.Equal(default, dst); } - public ValueMapper GetMapper() + private static ValueMapper GetMapper(ColumnType dstType) { - Assert.True(typeof(TDst).TryGetDataKind(out DataKind dstDataKind)); + Assert.True(typeof(TDst) == dstType.RawType); return Conversions.Instance.GetStandardConversion( - TextType.Instance, NumberType.FromKind(dstDataKind), out bool identity); + TextType.Instance, dstType, out bool identity); } } } - - diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index af9fedfde3..5a38ecb6d0 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -504,7 +504,7 @@ public void TestCrossValidationMacroWithMultiClass() b = schema.TryGetColumnIndex("Fold 
Index", out foldCol); Assert.True(b); var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, countCol); - Assert.True(type != null && type.ItemType.IsText && type.VectorSize == 10); + Assert.True(type is VectorType vecType && vecType.ItemType is TextType && vecType.Size == 10); var slotNames = default(VBuffer>); schema.GetMetadata(MetadataUtils.Kinds.SlotNames, countCol, ref slotNames); Assert.True(slotNames.Values.Select((s, i) => ReadOnlyMemoryUtils.EqualsStr(i.ToString(), s)).All(x => x)); @@ -992,7 +992,7 @@ public void TestTensorFlowEntryPoint() var schema = data.Schema; Assert.Equal(3, schema.ColumnCount); Assert.Equal("Softmax", schema.GetColumnName(2)); - Assert.Equal(10, schema.GetColumnType(2).VectorSize); + Assert.Equal(10, (schema.GetColumnType(2) as VectorType)?.Size); } } } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 523af789f4..1a5b46af14 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -851,7 +851,7 @@ public void EntryPointPipelineEnsemble() var hasScoreCol = binaryScored.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out int scoreIndex); Assert.True(hasScoreCol, "Data scored with binary ensemble does not have a score column"); var type = binaryScored.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.ScoreColumnKind, scoreIndex); - Assert.True(type != null && type.IsText, "Binary ensemble scored data does not have correct type of metadata."); + Assert.True(type is TextType, "Binary ensemble scored data does not have correct type of metadata."); var kind = default(ReadOnlyMemory); binaryScored.Schema.GetMetadata(MetadataUtils.Kinds.ScoreColumnKind, scoreIndex, ref kind); Assert.True(ReadOnlyMemoryUtils.EqualsStr(MetadataUtils.Const.ScoreColumnKind.BinaryClassification, kind), @@ -860,7 +860,7 @@ public void EntryPointPipelineEnsemble() 
hasScoreCol = regressionScored.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreIndex); Assert.True(hasScoreCol, "Data scored with regression ensemble does not have a score column"); type = regressionScored.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.ScoreColumnKind, scoreIndex); - Assert.True(type != null && type.IsText, "Regression ensemble scored data does not have correct type of metadata."); + Assert.True(type is TextType, "Regression ensemble scored data does not have correct type of metadata."); regressionScored.Schema.GetMetadata(MetadataUtils.Kinds.ScoreColumnKind, scoreIndex, ref kind); Assert.True(ReadOnlyMemoryUtils.EqualsStr(MetadataUtils.Const.ScoreColumnKind.Regression, kind), $"Regression ensemble scored data column type should be '{MetadataUtils.Const.ScoreColumnKind.Regression}', but is instead '{kind}'"); @@ -868,7 +868,7 @@ public void EntryPointPipelineEnsemble() hasScoreCol = anomalyScored.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreIndex); Assert.True(hasScoreCol, "Data scored with anomaly detection ensemble does not have a score column"); type = anomalyScored.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.ScoreColumnKind, scoreIndex); - Assert.True(type != null && type.IsText, "Anomaly detection ensemble scored data does not have correct type of metadata."); + Assert.True(type is TextType, "Anomaly detection ensemble scored data does not have correct type of metadata."); anomalyScored.Schema.GetMetadata(MetadataUtils.Kinds.ScoreColumnKind, scoreIndex, ref kind); Assert.True(ReadOnlyMemoryUtils.EqualsStr(MetadataUtils.Const.ScoreColumnKind.AnomalyDetection, kind), $"Anomaly detection ensemble scored data column type should be '{MetadataUtils.Const.ScoreColumnKind.AnomalyDetection}', but is instead '{kind}'"); diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestModelLoad.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestModelLoad.cs index 8891b49d22..87f77a9a3f 100644 --- 
a/test/Microsoft.ML.Core.Tests/UnitTests/TestModelLoad.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestModelLoad.cs @@ -46,8 +46,8 @@ public void LoadOldConcatTransformModel() Assert.Equal("Label", result.Schema[0].Name); Assert.Equal("Features", result.Schema[1].Name); Assert.Equal("Features", result.Schema[2].Name); - Assert.Equal(9, result.Schema[1].Type.VectorSize); - Assert.Equal(18, result.Schema[2].Type.VectorSize); + Assert.Equal(9, (result.Schema[1].Type as VectorType)?.Size); + Assert.Equal(18, (result.Schema[2].Type as VectorType)?.Size); } } } diff --git a/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs b/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs index 4ce87f070b..c6cdbf543a 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs @@ -24,16 +24,18 @@ private static T[] NaiveTranspose(IDataView view, int col) { var type = view.Schema.GetColumnType(col); int rc = checked((int)DataViewUtils.ComputeRowCount(view)); - Assert.True(type.ItemType.RawType == typeof(T)); - Assert.True(type.ValueCount > 0); - T[] retval = new T[rc * type.ValueCount]; + var vecType = type as VectorType; + var itemType = vecType?.ItemType ?? type; + Assert.Equal(typeof(T), itemType.RawType); + Assert.NotEqual(0, vecType?.Size); + T[] retval = new T[rc * (vecType?.Size ?? 
1)]; using (var cursor = view.GetRowCursor(c => c == col)) { - if (type.IsVector) + if (type is VectorType) { var getter = cursor.GetGetter>(col); - VBuffer temp = default(VBuffer); + VBuffer temp = default; int offset = 0; while (cursor.MoveNext()) { @@ -60,21 +62,20 @@ private static T[] NaiveTranspose(IDataView view, int col) private static void TransposeCheckHelper(IDataView view, int viewCol, ITransposeDataView trans) { int col = viewCol; - var type = trans.TransposeSchema.GetSlotType(col); - var colType = trans.Schema.GetColumnType(col); + VectorType type = trans.TransposeSchema.GetSlotType(col); + ColumnType colType = trans.Schema.GetColumnType(col); Assert.Equal(view.Schema.GetColumnName(viewCol), trans.Schema.GetColumnName(col)); - var expectedType = view.Schema.GetColumnType(viewCol); - // Unfortunately can't use equals because column type equality is a simple reference comparison. :P + ColumnType expectedType = view.Schema.GetColumnType(viewCol); Assert.Equal(expectedType, colType); - Assert.Equal(DataViewUtils.ComputeRowCount(view), (long)type.VectorSize); string desc = string.Format("Column {0} named '{1}'", col, trans.Schema.GetColumnName(col)); + Assert.Equal(DataViewUtils.ComputeRowCount(view), (long)type.Size); Assert.True(typeof(T) == type.ItemType.RawType, $"{desc} had wrong type for slot cursor"); - Assert.True(type.IsVector, $"{desc} expected to be vector but is not"); - Assert.True(type.VectorSize > 0, $"{desc} expected to be known sized vector but is not"); - Assert.True(0 != colType.ValueCount, $"{desc} expected to have fixed size, but does not"); - int rc = type.VectorSize; + Assert.True(type.Size > 0, $"{desc} expected to be known sized vector but is not"); + int valueCount = (colType as VectorType)?.Size ?? 
1; + Assert.True(0 != valueCount, $"{desc} expected to have fixed size, but does not"); + int rc = type.Size; T[] expectedVals = NaiveTranspose(view, viewCol); - T[] vals = new T[rc * colType.ValueCount]; + T[] vals = new T[rc * valueCount]; Contracts.Assert(vals.Length == expectedVals.Length); using (var cursor = trans.GetSlotCursor(col)) { @@ -89,7 +90,7 @@ private static void TransposeCheckHelper(IDataView view, int viewCol, ITransp temp.CopyTo(vals, offset); offset += rc; } - Assert.True(colType.ValueCount == offset / rc, $"{desc} slot cursor yielded fewer than expected values"); + Assert.True(valueCount == offset / rc, $"{desc} slot cursor yielded fewer than expected values"); } for (int i = 0; i < vals.Length; ++i) Assert.Equal(expectedVals[i], vals[i]); diff --git a/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs b/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs index ea6fee04d3..3b28fc1cfa 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs @@ -28,13 +28,13 @@ public void SimpleImageSmokeTest() var schema = reader.AsDynamic.GetOutputSchema(); Assert.True(schema.TryGetColumnIndex("Data", out int col), "Could not find 'Data' column"); var type = schema.GetColumnType(col); - Assert.True(type.IsKnownSizeVector, $"Type was supposed to be known size vector but was instead '{type}'"); - var vecType = type.AsVector; + var vecType = type as VectorType; + Assert.True(vecType?.Size > 0, $"Type was supposed to be known size vector but was instead '{type}'"); Assert.Equal(NumberType.R4, vecType.ItemType); - Assert.Equal(3, vecType.DimCount); - Assert.Equal(3, vecType.GetDim(0)); - Assert.Equal(8, vecType.GetDim(1)); - Assert.Equal(10, vecType.GetDim(2)); + Assert.Equal(3, vecType.Dimensions.Length); + Assert.Equal(3, vecType.Dimensions[0]); + Assert.Equal(8, vecType.Dimensions[1]); + Assert.Equal(10, vecType.Dimensions[2]); var readAsImage = 
TextLoader.CreateReader(env, ctx => ctx.LoadText(0).LoadAsImage()); diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index 3a3eff6289..de118ca2f5 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -210,7 +210,7 @@ public void AssertStaticSimple() var schema = SimpleSchemaUtils.Create(env, P("hello", TextType.Instance), P("my", new VectorType(NumberType.I8, 5)), - P("friend", new KeyType(DataKind.U4, 0, 3))); + P("friend", new KeyType(typeof(uint), 0, 3))); var view = new EmptyDataView(env, schema); view.AssertStatic(env, c => new @@ -234,7 +234,7 @@ public void AssertStaticSimpleFailure() var schema = SimpleSchemaUtils.Create(env, P("hello", TextType.Instance), P("my", new VectorType(NumberType.I8, 5)), - P("friend", new KeyType(DataKind.U4, 0, 3))); + P("friend", new KeyType(typeof(uint), 0, 3))); var view = new EmptyDataView(env, schema); Assert.ThrowsAny(() => @@ -269,23 +269,23 @@ public void AssertStaticKeys() var metaValues1 = new VBuffer>(3, new[] { "a".AsMemory(), "b".AsMemory(), "c".AsMemory() }); var meta1 = RowColumnUtils.GetColumn(MetadataUtils.Kinds.KeyValues, new VectorType(TextType.Instance, 3), ref metaValues1); uint value1 = 2; - var col1 = RowColumnUtils.GetColumn("stay", new KeyType(DataKind.U4, 0, 3), ref value1, RowColumnUtils.GetRow(counted, meta1)); + var col1 = RowColumnUtils.GetColumn("stay", new KeyType(typeof(uint), 0, 3), ref value1, RowColumnUtils.GetRow(counted, meta1)); // Next the case where those values are ints. 
var metaValues2 = new VBuffer(3, new int[] { 1, 2, 3, 4 }); var meta2 = RowColumnUtils.GetColumn(MetadataUtils.Kinds.KeyValues, new VectorType(NumberType.I4, 4), ref metaValues2); var value2 = new VBuffer(2, 0, null, null); - var col2 = RowColumnUtils.GetColumn("awhile", new VectorType(new KeyType(DataKind.U1, 2, 4), 2), ref value2, RowColumnUtils.GetRow(counted, meta2)); + var col2 = RowColumnUtils.GetColumn("awhile", new VectorType(new KeyType(typeof(byte), 2, 4), 2), ref value2, RowColumnUtils.GetRow(counted, meta2)); // Then the case where a value of that kind exists, but is of not of the right kind, in which case it should not be identified as containing that metadata. var metaValues3 = (float)2; var meta3 = RowColumnUtils.GetColumn(MetadataUtils.Kinds.KeyValues, NumberType.R4, ref metaValues3); var value3 = (ushort)1; - var col3 = RowColumnUtils.GetColumn("and", new KeyType(DataKind.U2, 0, 2), ref value3, RowColumnUtils.GetRow(counted, meta3)); + var col3 = RowColumnUtils.GetColumn("and", new KeyType(typeof(ushort), 0, 2), ref value3, RowColumnUtils.GetRow(counted, meta3)); // Then a final case where metadata of that kind is actaully simply altogether absent. var value4 = new VBuffer(5, 0, null, null); - var col4 = RowColumnUtils.GetColumn("listen", new VectorType(new KeyType(DataKind.U4, 0, 2)), ref value4); + var col4 = RowColumnUtils.GetColumn("listen", new VectorType(new KeyType(typeof(uint), 0, 2)), ref value4); // Finally compose a trivial data view out of all this. 
var row = RowColumnUtils.GetRow(counted, col1, col2, col3, col4); @@ -456,9 +456,9 @@ public void ToKey() Assert.True(schema.TryGetColumnIndex("valuesKey", out int valuesCol)); Assert.True(schema.TryGetColumnIndex("valuesKeyKey", out int valuesKeyCol)); - Assert.Equal(3, schema.GetColumnType(labelCol).KeyCount); - Assert.True(schema.GetColumnType(valuesCol).ItemType.IsKey); - Assert.True(schema.GetColumnType(valuesKeyCol).ItemType.IsKey); + Assert.Equal(3, (schema.GetColumnType(labelCol) as KeyType)?.Count); + Assert.True(schema.GetColumnType(valuesCol) is VectorType valuesVecType && valuesVecType.ItemType is KeyType); + Assert.True(schema.GetColumnType(valuesKeyCol) is VectorType valuesKeyVecType && valuesKeyVecType.ItemType is KeyType); var labelKeyType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, labelCol); var valuesKeyType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, valuesCol); @@ -466,9 +466,9 @@ public void ToKey() Assert.NotNull(labelKeyType); Assert.NotNull(valuesKeyType); Assert.NotNull(valuesKeyKeyType); - Assert.True(labelKeyType.IsVector && labelKeyType.ItemType == TextType.Instance); - Assert.True(valuesKeyType.IsVector && valuesKeyType.ItemType == NumberType.Float); - Assert.True(valuesKeyKeyType.IsVector && valuesKeyKeyType.ItemType == NumberType.Float); + Assert.True(labelKeyType is VectorType labelVecType && labelVecType.ItemType == TextType.Instance); + Assert.True(valuesKeyType is VectorType valuesVecType2 && valuesVecType2.ItemType == NumberType.Float); + Assert.True(valuesKeyKeyType is VectorType valuesKeyVecType2 && valuesKeyVecType2.ItemType == NumberType.Float); // Because they're over exactly the same data, they ought to have the same cardinality and everything. 
Assert.True(valuesKeyKeyType.Equals(valuesKeyType)); } @@ -501,9 +501,9 @@ public void ConcatWith() for (int i = 0; i < idx.Length; ++i) { var type = schema.GetColumnType(idx[i]); - Assert.True(type.VectorSize > 0, $"Col c{i} had unexpected type {type}"); - types[i] = type.AsVector; - Assert.Equal(expectedLen[i], type.VectorSize); + types[i] = type as VectorType; + Assert.True(types[i]?.Size > 0, $"Col c{i} had unexpected type {type}"); + Assert.Equal(expectedLen[i], types[i].Size); } Assert.Equal(TextType.Instance, types[0].ItemType); Assert.Equal(TextType.Instance, types[1].ItemType); @@ -533,12 +533,11 @@ public void Tokenize() Assert.True(schema.TryGetColumnIndex("tokens", out int tokensCol)); var type = schema.GetColumnType(tokensCol); - Assert.True(type.IsVector && !type.IsKnownSizeVector && type.ItemType.IsText); - + Assert.True(type is VectorType vecType && vecType.Size == 0 && vecType.ItemType == TextType.Instance); Assert.True(schema.TryGetColumnIndex("chars", out int charsCol)); type = schema.GetColumnType(charsCol); - Assert.True(type.IsVector && !type.IsKnownSizeVector && type.ItemType.IsKey); - Assert.True(type.ItemType.AsKey.RawKind == DataKind.U2); + Assert.True(type is VectorType vecType2 && vecType2.Size == 0 && vecType2.ItemType is KeyType + && vecType2.ItemType.RawType == typeof(ushort)); } [Fact] @@ -563,11 +562,11 @@ public void NormalizeTextAndRemoveStopWords() Assert.True(schema.TryGetColumnIndex("words_without_stopwords", out int stopwordsCol)); var type = schema.GetColumnType(stopwordsCol); - Assert.True(type.IsVector && !type.IsKnownSizeVector && type.ItemType.IsText); + Assert.True(type is VectorType vecType && vecType.Size == 0 && vecType.ItemType == TextType.Instance); Assert.True(schema.TryGetColumnIndex("normalized_text", out int normTextCol)); type = schema.GetColumnType(normTextCol); - Assert.True(type.IsText && !type.IsVector); + Assert.Equal(TextType.Instance, type); } [Fact] @@ -592,11 +591,11 @@ public void ConvertToWordBag() 
Assert.True(schema.TryGetColumnIndex("bagofword", out int bagofwordCol)); var type = schema.GetColumnType(bagofwordCol); - Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber); + Assert.True(type is VectorType vecType && vecType.Size > 0&& vecType.ItemType is NumberType); Assert.True(schema.TryGetColumnIndex("bagofhashedword", out int bagofhashedwordCol)); type = schema.GetColumnType(bagofhashedwordCol); - Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber); + Assert.True(type is VectorType vecType2 && vecType2.Size > 0 && vecType2.ItemType is NumberType); } [Fact] @@ -621,11 +620,11 @@ public void Ngrams() Assert.True(schema.TryGetColumnIndex("ngrams", out int ngramsCol)); var type = schema.GetColumnType(ngramsCol); - Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber); + Assert.True(type is VectorType vecType && vecType.Size > 0 && vecType.ItemType is NumberType); Assert.True(schema.TryGetColumnIndex("ngramshash", out int ngramshashCol)); type = schema.GetColumnType(ngramshashCol); - Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber); + Assert.True(type is VectorType vecType2 && vecType2.Size > 0 && vecType2.ItemType is NumberType); } @@ -652,19 +651,19 @@ public void LpGcNormAndWhitening() Assert.True(schema.TryGetColumnIndex("lpnorm", out int lpnormCol)); var type = schema.GetColumnType(lpnormCol); - Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber); + Assert.True(type is VectorType vecType && vecType.Size > 0 && vecType.ItemType is NumberType); Assert.True(schema.TryGetColumnIndex("gcnorm", out int gcnormCol)); type = schema.GetColumnType(gcnormCol); - Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber); + Assert.True(type is VectorType vecType2 && vecType2.Size > 0 && vecType2.ItemType is NumberType); Assert.True(schema.TryGetColumnIndex("zcawhitened", out int zcawhitenedCol)); type = 
schema.GetColumnType(zcawhitenedCol); - Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber); + Assert.True(type is VectorType vecType3 && vecType3.Size > 0 && vecType3.ItemType is NumberType); Assert.True(schema.TryGetColumnIndex("pcswhitened", out int pcswhitenedCol)); type = schema.GetColumnType(pcswhitenedCol); - Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber); + Assert.True(type is VectorType vecType4 && vecType4.Size > 0 && vecType4.ItemType is NumberType); } [Fact] @@ -691,8 +690,8 @@ public void LdaTopicModel() Assert.True(schema.TryGetColumnIndex("topics", out int topicsCol)); var type = schema.GetColumnType(topicsCol); - Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber); - } + Assert.True(type is VectorType vecType && vecType.Size > 0 && vecType.ItemType is NumberType); +} [Fact] public void FeatureSelection() @@ -716,11 +715,11 @@ public void FeatureSelection() Assert.True(schema.TryGetColumnIndex("bag_of_words_count", out int bagofwordCountCol)); var type = schema.GetColumnType(bagofwordCountCol); - Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber); + Assert.True(type is VectorType vecType && vecType.Size > 0 && vecType.ItemType is NumberType); Assert.True(schema.TryGetColumnIndex("bag_of_words_mi", out int bagofwordMiCol)); type = schema.GetColumnType(bagofwordMiCol); - Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber); + Assert.True(type is VectorType vecType2 && vecType2.Size > 0 && vecType2.ItemType is NumberType); } [Fact] @@ -773,7 +772,7 @@ public void PrincipalComponentAnalysis() Assert.True(schema.TryGetColumnIndex("pca", out int pcaCol)); var type = schema.GetColumnType(pcaCol); - Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber); + Assert.True(type is VectorType vecType && vecType.Size > 0 && vecType.ItemType is NumberType); } [Fact] @@ -844,20 +843,20 @@ public void 
TextNormalizeStatic() Assert.True(schema.TryGetColumnIndex("norm", out int norm)); var type = schema.GetColumnType(norm); - Assert.True(!type.IsVector && type.ItemType.IsText); + Assert.True(type is TextType); Assert.True(schema.TryGetColumnIndex("norm_Upper", out int normUpper)); type = schema.GetColumnType(normUpper); - Assert.True(!type.IsVector && type.ItemType.IsText); + Assert.True(type is TextType); Assert.True(schema.TryGetColumnIndex("norm_KeepDiacritics", out int diacritics)); type = schema.GetColumnType(diacritics); - Assert.True(!type.IsVector && type.ItemType.IsText); + Assert.True(type is TextType); Assert.True(schema.TryGetColumnIndex("norm_NoPuctuations", out int punct)); type = schema.GetColumnType(punct); - Assert.True(!type.IsVector && type.ItemType.IsText); + Assert.True(type is TextType); Assert.True(schema.TryGetColumnIndex("norm_NoNumbers", out int numbers)); type = schema.GetColumnType(numbers); - Assert.True(!type.IsVector && type.ItemType.IsText); + Assert.True(type is TextType); } [Fact] @@ -876,8 +875,7 @@ public void TestPcaStatic() Assert.True(schema.TryGetColumnIndex("pca", out int pca)); var type = schema[pca].Type; - Assert.True(type.IsVector && type.ItemType.RawKind == DataKind.R4); - Assert.True(type.VectorSize == 5); + Assert.Equal(new VectorType(NumberType.R4, 5), type); } [Fact] @@ -900,14 +898,13 @@ public void TestConvertStatic() Assert.True(schema.TryGetColumnIndex("floatLabel", out int floatLabel)); var type = schema[floatLabel].Type; - Assert.True(!type.IsVector && type.ItemType.RawKind == DataKind.R4); + Assert.Equal(NumberType.R4, type); Assert.True(schema.TryGetColumnIndex("txtFloat", out int txtFloat)); type = schema[txtFloat].Type; - Assert.True(!type.IsVector && type.ItemType.RawKind == DataKind.R4); + Assert.Equal(NumberType.R4, type); Assert.True(schema.TryGetColumnIndex("num", out int num)); type = schema[num].Type; - Assert.True(type.IsVector && type.ItemType.RawKind == DataKind.R4); - Assert.True(type.VectorSize 
== 3); + Assert.Equal(new VectorType(NumberType.R4, 3), type); } } } \ No newline at end of file diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 503ec4c3c7..4cbbafb346 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -194,13 +194,13 @@ public void TensorFlowInputsOutputsSchemaTest() var schema = TensorFlowUtils.GetModelSchema(env, model_location); Assert.Equal(86, schema.ColumnCount); Assert.True(schema.TryGetColumnIndex("Placeholder", out int col)); - var type = schema.GetColumnType(col).AsVector; - Assert.Equal(2, type.DimCount); - Assert.Equal(28, type.GetDim(0)); - Assert.Equal(28, type.GetDim(1)); + var type = (VectorType)schema.GetColumnType(col); + Assert.Equal(2, type.Dimensions.Length); + Assert.Equal(28, type.Dimensions[0]); + Assert.Equal(28, type.Dimensions[1]); var metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); Assert.NotNull(metadataType); - Assert.True(metadataType.IsText); + Assert.True(metadataType is TextType); ReadOnlyMemory opType = default; schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); Assert.Equal("Placeholder", opType.ToString()); @@ -208,15 +208,11 @@ public void TensorFlowInputsOutputsSchemaTest() Assert.Null(metadataType); Assert.True(schema.TryGetColumnIndex("conv2d/Conv2D/ReadVariableOp", out col)); - type = schema.GetColumnType(col).AsVector; - Assert.Equal(4, type.DimCount); - Assert.Equal(5, type.GetDim(0)); - Assert.Equal(5, type.GetDim(1)); - Assert.Equal(1, type.GetDim(2)); - Assert.Equal(32, type.GetDim(3)); + type = (VectorType)schema.GetColumnType(col); + Assert.Equal(new[] { 5, 5, 1, 32 }, type.Dimensions); metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); Assert.NotNull(metadataType); - 
Assert.True(metadataType.IsText); + Assert.True(metadataType is TextType); schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); Assert.Equal("Identity", opType.ToString()); metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); @@ -227,14 +223,11 @@ public void TensorFlowInputsOutputsSchemaTest() Assert.Equal("conv2d/kernel", inputOps.Values[0].ToString()); Assert.True(schema.TryGetColumnIndex("conv2d/Conv2D", out col)); - type = schema.GetColumnType(col).AsVector; - Assert.Equal(3, type.DimCount); - Assert.Equal(28, type.GetDim(0)); - Assert.Equal(28, type.GetDim(1)); - Assert.Equal(32, type.GetDim(2)); + type = (VectorType)schema.GetColumnType(col); + Assert.Equal(new[] { 28, 28, 32 }, type.Dimensions); metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); Assert.NotNull(metadataType); - Assert.True(metadataType.IsText); + Assert.True(metadataType is TextType); schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); Assert.Equal("Conv2D", opType.ToString()); metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); @@ -245,12 +238,11 @@ public void TensorFlowInputsOutputsSchemaTest() Assert.Equal("conv2d/Conv2D/ReadVariableOp", inputOps.Values[1].ToString()); Assert.True(schema.TryGetColumnIndex("Softmax", out col)); - type = schema.GetColumnType(col).AsVector; - Assert.Equal(1, type.DimCount); - Assert.Equal(10, type.GetDim(0)); + type = (VectorType)schema.GetColumnType(col); + Assert.Equal(new[] { 10 }, type.Dimensions); metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); Assert.NotNull(metadataType); - Assert.True(metadataType.IsText); + Assert.True(metadataType is TextType); schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); Assert.Equal("Softmax", opType.ToString()); metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); @@ -265,10 +257,8 @@ public void TensorFlowInputsOutputsSchemaTest() for (int i = 0; i < schema.ColumnCount; i++) { 
Assert.Equal(name.ToString(), schema.GetColumnName(i)); - type = schema.GetColumnType(i).AsVector; - Assert.Equal(2, type.DimCount); - Assert.Equal(2, type.GetDim(0)); - Assert.Equal(2, type.GetDim(1)); + type = (VectorType)schema.GetColumnType(i); + Assert.Equal(new[] { 2, 2 }, type.Dimensions); name++; } } @@ -787,9 +777,9 @@ public void TensorFlowTransformCifar() var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(env, model_location); var schema = tensorFlowModel.GetInputSchema(); Assert.True(schema.TryGetColumnIndex("Input", out int column)); - var type = schema.GetColumnType(column).AsVector; - var imageHeight = type.GetDim(0); - var imageWidth = type.GetDim(1); + var type = (VectorType)schema.GetColumnType(column); + var imageHeight = type.Dimensions[0]; + var imageWidth = type.Dimensions[1]; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); @@ -853,9 +843,9 @@ public void TensorFlowTransformCifarSavedModel() var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(env, model_location); var schema = tensorFlowModel.GetInputSchema(); Assert.True(schema.TryGetColumnIndex("Input", out int column)); - var type = schema.GetColumnType(column).AsVector; - var imageHeight = type.GetDim(0); - var imageWidth = type.GetDim(1); + var type = (VectorType)schema.GetColumnType(column); + var imageHeight = type.Dimensions[0]; + var imageWidth = type.Dimensions[1]; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); diff --git a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs index 690bd0dd2f..2f2820692f 100644 --- a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs @@ -193,9 +193,9 @@ public void TestTensorFlowStaticWithSchema() var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(env, modelLocation); var schema = tensorFlowModel.GetInputSchema(); 
Assert.True(schema.TryGetColumnIndex("Input", out int column)); - var type = schema.GetColumnType(column).AsVector; - var imageHeight = type.GetDim(0); - var imageWidth = type.GetDim(1); + var type = (VectorType)schema.GetColumnType(column); + var imageHeight = type.Dimensions[0]; + var imageWidth = type.Dimensions[1]; var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); diff --git a/test/Microsoft.ML.Tests/TermEstimatorTests.cs b/test/Microsoft.ML.Tests/TermEstimatorTests.cs index b5f8ddbe7f..c543523972 100644 --- a/test/Microsoft.ML.Tests/TermEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/TermEstimatorTests.cs @@ -146,7 +146,8 @@ void TestMetadataCopy() result.Schema.TryGetColumnIndex("T", out int termIndex); var names1 = default(VBuffer>); var type1 = result.Schema.GetColumnType(termIndex); - int size = type1.ItemType.IsKey ? type1.ItemType.KeyCount : -1; + var itemType1 = (type1 as VectorType)?.ItemType ?? type1; + int size = itemType1 is KeyType keyType ? keyType.Count : -1; result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, termIndex, ref names1); Assert.True(names1.Count > 0); } diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs index bfa22dce86..cb5a8b2748 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs @@ -194,12 +194,12 @@ public void MatrixFactorizationInMemoryData() // Check if the expected types in the trained model are expected. 
Assert.True(model.MatrixColumnIndexColumnName == "MatrixColumnIndex"); Assert.True(model.MatrixRowIndexColumnName == "MatrixRowIndex"); - Assert.True(model.MatrixColumnIndexColumnType.IsKey); - Assert.True(model.MatrixRowIndexColumnType.IsKey); - var matColKeyType = model.MatrixColumnIndexColumnType.AsKey; + Assert.True(model.MatrixColumnIndexColumnType is KeyType); + Assert.True(model.MatrixRowIndexColumnType is KeyType); + var matColKeyType = (KeyType)model.MatrixColumnIndexColumnType; Assert.True(matColKeyType.Min == _synthesizedMatrixFirstColumnIndex); Assert.True(matColKeyType.Count == _synthesizedMatrixColumnCount); - var matRowKeyType = model.MatrixRowIndexColumnType.AsKey; + var matRowKeyType = (KeyType)model.MatrixRowIndexColumnType; Assert.True(matRowKeyType.Min == _synthesizedMatrixFirstRowIndex); Assert.True(matRowKeyType.Count == _synthesizedMatrixRowCount); @@ -285,12 +285,12 @@ public void MatrixFactorizationInMemoryDataZeroBaseIndex() // Check if the expected types in the trained model are expected. 
Assert.True(model.MatrixColumnIndexColumnName == nameof(MatrixElementZeroBased.MatrixColumnIndex)); Assert.True(model.MatrixRowIndexColumnName == nameof(MatrixElementZeroBased.MatrixRowIndex)); - Assert.True(model.MatrixColumnIndexColumnType.IsKey); - Assert.True(model.MatrixRowIndexColumnType.IsKey); - var matColKeyType = model.MatrixColumnIndexColumnType.AsKey; + var matColKeyType = model.MatrixColumnIndexColumnType as KeyType; + Assert.NotNull(matColKeyType); + var matRowKeyType = model.MatrixRowIndexColumnType as KeyType; + Assert.NotNull(matRowKeyType); Assert.True(matColKeyType.Min == 0); Assert.True(matColKeyType.Count == _synthesizedMatrixColumnCount); - var matRowKeyType = model.MatrixRowIndexColumnType.AsKey; Assert.True(matRowKeyType.Min == 0); Assert.True(matRowKeyType.Count == _synthesizedMatrixRowCount); diff --git a/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs b/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs index 9133bdbcf9..efb3833be3 100644 --- a/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/ConcatTests.cs @@ -55,13 +55,13 @@ ColumnType GetType(Schema schema, string name) ColumnType t; t = GetType(data.Schema, "f1"); - Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 1); + Assert.True(t is VectorType vt1 && vt1.ItemType == NumberType.R4 && vt1.Size == 1); t = GetType(data.Schema, "f2"); - Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 2); + Assert.True(t is VectorType vt2 && vt2.ItemType == NumberType.R4 && vt2.Size == 2); t = GetType(data.Schema, "f3"); - Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 5); + Assert.True(t is VectorType vt3 && vt3.ItemType == NumberType.R4 && vt3.Size == 5); t = GetType(data.Schema, "f4"); - Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 0); + Assert.True(t is VectorType vt4 && vt4.ItemType == NumberType.R4 && vt4.Size == 0); data = 
SelectColumnsTransform.CreateKeep(Env, data, "f1", "f2", "f3", "f4"); @@ -111,9 +111,9 @@ ColumnType GetType(Schema schema, string name) ColumnType t; t = GetType(data.Schema, "f2"); - Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 2); + Assert.True(t is VectorType vt2 && vt2.ItemType == NumberType.R4 && vt2.Size == 2); t = GetType(data.Schema, "f3"); - Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 5); + Assert.True(t is VectorType vt3 && vt3.ItemType == NumberType.R4 && vt3.Size == 5); data = SelectColumnsTransform.CreateKeep(Env, data, "f2", "f3"); diff --git a/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs index 019103ff44..5e63e694ba 100644 --- a/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs @@ -152,7 +152,8 @@ void TestMetadataCopy() var names1 = default(VBuffer>); var names2 = default(VBuffer>); var type1 = result.Schema.GetColumnType(termIndex); - int size = type1.ItemType.IsKey ? type1.ItemType.KeyCount : -1; + var itemType1 = (type1 as VectorType)?.ItemType ?? type1; + int size = (itemType1 as KeyType)?.Count ?? 
-1; var type2 = result.Schema.GetColumnType(copyIndex); result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, termIndex, ref names1); result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, copyIndex, ref names2); diff --git a/test/Microsoft.ML.Tests/Transformers/HashTests.cs b/test/Microsoft.ML.Tests/Transformers/HashTests.cs index e2c2a0159f..a8812b1337 100644 --- a/test/Microsoft.ML.Tests/Transformers/HashTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/HashTests.cs @@ -245,20 +245,20 @@ private void HashTestPositiveIntegerCore(ulong value, uint expected, uint expect if (value <= byte.MaxValue) { HashTestCore((byte)value, NumberType.U1, expected, expectedOrdered, expectedOrdered3); - HashTestCore((byte)value, new KeyType(DataKind.U1, 0, byte.MaxValue - 1), eKey, eoKey, e3Key); + HashTestCore((byte)value, new KeyType(typeof(byte), 0, byte.MaxValue - 1), eKey, eoKey, e3Key); } if (value <= ushort.MaxValue) { HashTestCore((ushort)value, NumberType.U2, expected, expectedOrdered, expectedOrdered3); - HashTestCore((ushort)value, new KeyType(DataKind.U2, 0, ushort.MaxValue - 1), eKey, eoKey, e3Key); + HashTestCore((ushort)value, new KeyType(typeof(ushort), 0, ushort.MaxValue - 1), eKey, eoKey, e3Key); } if (value <= uint.MaxValue) { HashTestCore((uint)value, NumberType.U4, expected, expectedOrdered, expectedOrdered3); - HashTestCore((uint)value, new KeyType(DataKind.U4, 0, int.MaxValue - 1), eKey, eoKey, e3Key); + HashTestCore((uint)value, new KeyType(typeof(uint), 0, int.MaxValue - 1), eKey, eoKey, e3Key); } HashTestCore(value, NumberType.U8, expected, expectedOrdered, expectedOrdered3); - HashTestCore((ulong)value, new KeyType(DataKind.U8, 0, 0), eKey, eoKey, e3Key); + HashTestCore((ulong)value, new KeyType(typeof(ulong), 0, 0), eKey, eoKey, e3Key); HashTestCore(new UInt128(value, 0), NumberType.UG, expected, expectedOrdered, expectedOrdered3); diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs 
b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs index 0c909ef97b..b436c96149 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs @@ -271,7 +271,7 @@ public void LdaWorkout() using (var fs = File.Create(outputPath)) DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); - Assert.Equal(10, savedData.Schema.GetColumnType(0).VectorSize); + Assert.Equal(10, (savedData.Schema.GetColumnType(0) as VectorType)?.Size); } // Diabling this check due to the following issue with consitency of output.