diff --git a/src/Microsoft.ML.Core/Data/ColumnType.cs b/src/Microsoft.ML.Core/Data/ColumnType.cs index 83b5caa27d..628d955907 100644 --- a/src/Microsoft.ML.Core/Data/ColumnType.cs +++ b/src/Microsoft.ML.Core/Data/ColumnType.cs @@ -80,12 +80,6 @@ private protected ColumnType(Type rawType, DataKind rawKind) [BestFriend] internal bool IsPrimitive { get; } - /// - /// Equivalent to as . - /// - [BestFriend] - internal PrimitiveType AsPrimitive => IsPrimitive ? (PrimitiveType)this : null; - /// /// Whether this type is a standard numeric type. External code should use is . /// @@ -140,12 +134,6 @@ internal bool IsBool [BestFriend] internal bool IsKey { get; } - /// - /// Equivalent to as . - /// - [BestFriend] - internal KeyType AsKey => IsKey ? (KeyType)this : null; - /// /// Zero return means either it's not a key type or the cardinality is unknown. External code should first /// test whether this is of type , then if so get the property @@ -166,12 +154,6 @@ internal bool IsBool [BestFriend] internal bool IsVector { get; } - /// - /// Equivalent to as . - /// - [BestFriend] - internal VectorType AsVector => IsVector ? (VectorType)this : null; - /// /// For non-vector types, this returns the column type itself (i.e., return this). /// diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs index 024953f173..723c38fc74 100644 --- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs +++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs @@ -165,12 +165,12 @@ public static VectorType GetCategoricalType(int rangeCount) return new VectorType(NumberType.I4, rangeCount, 2); } - private static volatile ColumnType _scoreColumnSetIdType; + private static volatile KeyType _scoreColumnSetIdType; /// /// The type of the ScoreColumnSetId metadata. /// - public static ColumnType ScoreColumnSetIdType + public static KeyType ScoreColumnSetIdType { get { diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs index 868dd83b19..a6541c8b8b 100644 --- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs +++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs @@ -210,7 +210,7 @@ private void RunCore(IChannel ch) private bool ShouldAddColumn(Schema schema, int i, uint scoreSet, bool outputNamesAndLabels) { uint scoreSetId = 0; - if (schema.TryGetMetadata(MetadataUtils.ScoreColumnSetIdType.AsPrimitive, MetadataUtils.Kinds.ScoreColumnSetId, i, ref scoreSetId) + if (schema.TryGetMetadata(MetadataUtils.ScoreColumnSetIdType, MetadataUtils.Kinds.ScoreColumnSetId, i, ref scoreSetId) && scoreSetId == scoreSet) { return true; diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs index 1cfe187f24..e80135d6c8 100644 --- a/src/Microsoft.ML.Data/Data/Conversion.cs +++ b/src/Microsoft.ML.Data/Data/Conversion.cs @@ -412,15 +412,12 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst, conv = null; identity = false; - if (typeSrc.IsKey) + if (typeSrc is KeyType keySrc) { - var keySrc = typeSrc.AsKey; - // Key types are only convertable to compatible key types or unsigned integer // types that are large enough. - if (typeDst.IsKey) + if (typeDst is KeyType keyDst) { - var keyDst = typeDst.AsKey; // We allow the Min value to shift. We currently don't allow the counts to vary. // REVIEW: Should we allow the counts to vary? Allowing the dst to be bigger is trivial. // Smaller dst means mapping values to NA. @@ -451,11 +448,11 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst, // REVIEW: Should we look for illegal values and force them to zero? If so, then // we'll need to set identity to false. } - else if (typeDst.IsKey) + else if (typeDst is KeyType keyDst) { if (!typeSrc.IsText) return false; - conv = GetKeyParse(typeDst.AsKey); + conv = GetKeyParse(keyDst); return true; } else if (!typeDst.IsStandardScalar) @@ -490,10 +487,10 @@ public bool TryGetStringConversion(ColumnType type, out ValueMapper(type.AsKey); + conv = GetKeyStringConversion(keyType); return true; } return TryGetStringConversion(out conv); @@ -572,8 +569,8 @@ public TryParseMapper GetTryParseConversion(ColumnType typeDst) "Parse conversion only supported for standard types"); Contracts.Check(typeDst.RawType == typeof(TDst), "Wrong TDst type parameter"); - if (typeDst.IsKey) - return GetKeyTryParse(typeDst.AsKey); + if (typeDst is KeyType keyType) + return GetKeyTryParse(keyType); Contracts.Assert(_tryParseDelegates.ContainsKey(typeDst.RawKind)); return (TryParseMapper)_tryParseDelegates[typeDst.RawKind]; diff --git a/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs b/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs index c67c6ae733..604fe3c2b7 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs @@ -55,11 +55,11 @@ private static ColumnType MakeColumnType(SchemaShape.Column inputCol) { ColumnType curType = inputCol.ItemType; if (inputCol.IsKey) - curType = new KeyType(curType.AsPrimitive.RawKind, 0, AllKeySizes); + curType = new KeyType(((PrimitiveType)curType).RawKind, 0, AllKeySizes); if (inputCol.Kind == SchemaShape.Column.VectorKind.VariableVector) - curType = new VectorType(curType.AsPrimitive, 0); + curType = new VectorType((PrimitiveType)curType, 0); else if (inputCol.Kind == SchemaShape.Column.VectorKind.Vector) - curType = new VectorType(curType.AsPrimitive, AllVectorSizes); + curType = new VectorType((PrimitiveType)curType, AllVectorSizes); return curType; } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index a3d5260531..7e8a90b75e 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -840,9 +840,8 @@ public void Save(ModelSaveContext ctx) Contracts.Assert((DataKind)(byte)type.RawKind == type.RawKind); ctx.Writer.Write((byte)type.RawKind); ctx.Writer.WriteBoolByte(type.IsKey); - if (type.IsKey) + if (type is KeyType key) { - var key = type.AsKey; ctx.Writer.WriteBoolByte(key.Contiguous); ctx.Writer.Write(key.Min); ctx.Writer.Write(key.Count); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs index cdcb507f51..f9d6cd2b09 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs @@ -663,12 +663,14 @@ public Parser(TextLoader parent) { var info = _infos[i]; - if (info.ColType.ItemType.IsKey) + if (info.ColType is KeyType keyType) { - if (!info.ColType.IsVector) - _creator[i] = cache.GetCreatorOne(info.ColType.AsKey); - else - _creator[i] = cache.GetCreatorVec(info.ColType.ItemType.AsKey); + _creator[i] = cache.GetCreatorOne(keyType); + continue; + } + else if (info.ColType is VectorType vectorType && vectorType.ItemType is KeyType vectorKeyType) + { + _creator[i] = cache.GetCreatorVec(vectorKeyType); continue; } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs index 03fd7dfa70..08bc2c3e3f 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs @@ -489,9 +489,8 @@ private TextLoader.Column GetColumn(string name, ColumnType type, int? start) { DataKind? kind; KeyRange keyRange = null; - if (type.ItemType.IsKey) + if (type.ItemType is KeyType key) { - var key = type.ItemType.AsKey; if (!key.Contiguous) keyRange = new KeyRange(key.Min, contiguous: false); else if (key.Count == 0) diff --git a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs index fe63426c61..83617f0f52 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs @@ -659,7 +659,7 @@ public VectorType GetSlotType(int col) var view = _parent._entries[col].GetViewOrNull(); if (view == null) return null; - return view.Schema.GetColumnType(0).AsVector; + return view.Schema.GetColumnType(0) as VectorType; } } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeSaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeSaver.cs index 9ce4ab3872..6878fbf935 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeSaver.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeSaver.cs @@ -70,7 +70,7 @@ public bool IsColumnSavable(ColumnType type) // an artificial vector type out of this. Obviously if you can't make a vector // out of the items, then you could not save each slot's values. var itemType = type.ItemType; - var primitiveType = itemType.AsPrimitive; + var primitiveType = itemType as PrimitiveType; if (primitiveType == null) return false; var vectorType = new VectorType(primitiveType, size: 2); diff --git a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs index 756876dd97..62d28e1809 100644 --- a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs @@ -208,12 +208,12 @@ private Delegate CreateGetter(ColumnType colType, InternalSchemaDefinition.Colum else Host.Assert(colType.RawType == outputType); - if (!colType.IsKey) + if (!(colType is KeyType keyType)) del = CreateDirectGetterDelegate; else { var keyRawType = colType.RawType; - Host.Assert(colType.AsKey.Contiguous); + Host.Assert(keyType.Contiguous); Func delForKey = CreateKeyGetterDelegate; return Utils.MarshalInvoke(delForKey, keyRawType, peek, colType); } @@ -299,9 +299,10 @@ private Delegate CreateDirectGetterDelegate(Delegate peekDel) private Delegate CreateKeyGetterDelegate(Delegate peekDel, ColumnType colType) { // Make sure the function is dealing with key. - Host.Check(colType.IsKey); + KeyType keyType = colType as KeyType; + Host.Check(keyType != null); // Following equations work only with contiguous key type. - Host.Check(colType.AsKey.Contiguous); + Host.Check(keyType.Contiguous); // Following equations work only with unsigned integers. Host.Check(typeof(TDst) == typeof(ulong) || typeof(TDst) == typeof(uint) || typeof(TDst) == typeof(byte) || typeof(TDst) == typeof(bool)); @@ -312,8 +313,8 @@ private Delegate CreateKeyGetterDelegate(Delegate peekDel, ColumnType colT TDst rawKeyValue = default; ulong key = 0; // the raw key value as ulong - ulong min = colType.AsKey.Min; - ulong max = min + (ulong)colType.AsKey.Count - 1; + ulong min = keyType.Min; + ulong max = min + (ulong)keyType.Count - 1; ulong result = 0; // the result as ulong ValueGetter getter = (ref TDst dst) => { diff --git a/src/Microsoft.ML.Data/DataView/Transposer.cs b/src/Microsoft.ML.Data/DataView/Transposer.cs index 901c76275a..cbd5dff214 100644 --- a/src/Microsoft.ML.Data/DataView/Transposer.cs +++ b/src/Microsoft.ML.Data/DataView/Transposer.cs @@ -306,8 +306,9 @@ public SchemaImpl(Transposer parent) { ColumnInfo srcInfo = _parent._cols[c]; var ctype = srcInfo.Type.ItemType; - _ectx.Assert(ctype.IsPrimitive); - _slotTypes[c] = new VectorType(ctype.AsPrimitive, _parent.RowCount); + var primitiveType = ctype as PrimitiveType; + _ectx.Assert(primitiveType != null); + _slotTypes[c] = new VectorType(primitiveType, _parent.RowCount); } AsSchema = Schema.Create(this); @@ -1189,7 +1190,7 @@ private sealed class ColumnSplitter : Splitter public ColumnSplitter(IDataView view, int col, int[] lims) : base(view, col) { - var type = _view.Schema.GetColumnType(SrcCol).AsVector; + var type = _view.Schema.GetColumnType(SrcCol) as VectorType; // Only valid use is for two or more slices. Contracts.Assert(Utils.Size(lims) >= 2); Contracts.AssertValue(type); diff --git a/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs b/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs index b2d2eb8450..f68fcff00a 100644 --- a/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs +++ b/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs @@ -44,7 +44,7 @@ public static CommonOutputs.TransformOutput SelectColumns(IHostEnvironment env, private static bool ShouldAddColumn(Schema schema, int i, string[] extraColumns, uint scoreSet) { uint scoreSetId = 0; - if (schema.TryGetMetadata(MetadataUtils.ScoreColumnSetIdType.AsPrimitive, MetadataUtils.Kinds.ScoreColumnSetId, i, ref scoreSetId) + if (schema.TryGetMetadata(MetadataUtils.ScoreColumnSetIdType, MetadataUtils.Kinds.ScoreColumnSetId, i, ref scoreSetId) && scoreSetId == scoreSet) { return true; diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs index 0eb63d3f0c..22703accc6 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs @@ -723,7 +723,7 @@ public static void ReconcileVectorKeyValues(IHostEnvironment env, IDataView[] vi (ref VBuffer> dst) => schema.GetMetadata(MetadataUtils.Kinds.SlotNames, index, ref dst); } views[i] = LambdaColumnMapper.Create(env, "ReconcileKeyValues", views[i], columnName, columnName, - type, new VectorType(keyType, type.AsVector), mapper, keyValueGetter, slotNamesGetter); + type, new VectorType(keyType, type as VectorType), mapper, keyValueGetter, slotNamesGetter); } } @@ -974,7 +974,7 @@ private static bool VerifyVectorColumnsMatch(int cachedSize, int col, IDataView private static IDataView AddVarLengthColumn(IHostEnvironment env, IDataView idv, string variableSizeVectorColumnName, ColumnType typeSrc) { return LambdaColumnMapper.Create(env, "ChangeToVarLength", idv, variableSizeVectorColumnName, - variableSizeVectorColumnName + "_VarLength", typeSrc, new VectorType(typeSrc.ItemType.AsPrimitive), + variableSizeVectorColumnName + "_VarLength", typeSrc, new VectorType((PrimitiveType)typeSrc.ItemType), (in VBuffer src, ref VBuffer dst) => src.CopyTo(ref dst)); } diff --git a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs index 12aaa47ad0..916f8ddc72 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs @@ -1001,11 +1001,11 @@ protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMa if (!perInst.Schema.TryGetColumnIndex(schema.Label.Name, out int labelCol)) throw Host.Except("Could not find column '{0}'", schema.Label.Name); var labelType = perInst.Schema.GetColumnType(labelCol); - if (labelType.IsKey && (!perInst.Schema.HasKeyValues(labelCol, labelType.KeyCount) || labelType.RawKind != DataKind.U4)) + if (labelType is KeyType keyType && (!perInst.Schema.HasKeyValues(labelCol, keyType.KeyCount) || labelType.RawKind != DataKind.U4)) { perInst = LambdaColumnMapper.Create(Host, "ConvertToDouble", perInst, schema.Label.Name, schema.Label.Name, perInst.Schema.GetColumnType(labelCol), NumberType.R8, - (in uint src, ref double dst) => dst = src == 0 ? double.NaN : src - 1 + (double)labelType.AsKey.Min); + (in uint src, ref double dst) => dst = src == 0 ? double.NaN : src - 1 + (double)keyType.Min); } var perInstSchema = perInst.Schema; diff --git a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs index fa2de9c726..f33fcaee1a 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs @@ -549,7 +549,7 @@ private void CheckInputColumnTypes(Schema schema, out ColumnType labelType, out var t = schema.GetColumnType(LabelIndex); if (!t.IsKnownSizeVector || (t.ItemType != NumberType.R4 && t.ItemType != NumberType.R8)) throw Host.Except("Label column '{0}' has type '{1}' but must be a known-size vector of R4 or R8", LabelCol, t); - labelType = new VectorType(t.ItemType.AsPrimitive, t.VectorSize); + labelType = new VectorType((PrimitiveType)t.ItemType, t.VectorSize); var slotNamesType = new VectorType(TextType.Instance, t.VectorSize); var builder = new MetadataBuilder(); builder.AddSlotNames(t.VectorSize, CreateSlotNamesGetter(schema, LabelIndex, labelType.VectorSize, "True")); @@ -558,7 +558,7 @@ private void CheckInputColumnTypes(Schema schema, out ColumnType labelType, out t = schema.GetColumnType(ScoreIndex); if (t.VectorSize == 0 || t.ItemType != NumberType.Float) throw Host.Except("Score column '{0}' has type '{1}' but must be a known length vector of type R4", ScoreCol, t); - scoreType = new VectorType(t.ItemType.AsPrimitive, t.VectorSize); + scoreType = new VectorType((PrimitiveType)t.ItemType, t.VectorSize); builder = new MetadataBuilder(); builder.AddSlotNames(t.VectorSize, CreateSlotNamesGetter(schema, ScoreIndex, scoreType.VectorSize, "Predicted")); diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs index 84acb75898..4e25546ae5 100644 --- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs @@ -123,7 +123,7 @@ private static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBound trainSchema.Label.Index, ref value); }; - return MultiClassClassifierScorer.LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)mapper, type.AsVector, getter, MetadataUtils.Kinds.TrainingLabelValues, CanWrap); + return MultiClassClassifierScorer.LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)mapper, type as VectorType, getter, MetadataUtils.Kinds.TrainingLabelValues, CanWrap); } public BinaryClassifierScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) diff --git a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs index 70638e35a9..28cc55f5e5 100644 --- a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs +++ b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs @@ -365,7 +365,7 @@ public RowMapper(IHostEnvironment env, BindableMapper parent, RoleMappedSchema s else { _outputSchema = Schema.Create(new FeatureContributionSchema(_env, DefaultColumnNames.FeatureContributions, - new VectorType(NumberType.R4, schema.Feature.Type.AsVector), + new VectorType(NumberType.R4, schema.Feature.Type as VectorType), InputSchema, InputRoleMappedSchema.Feature.Index)); } diff --git a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs index 4b190636f9..d5b7ff9669 100644 --- a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs @@ -130,9 +130,9 @@ private LabelNameBindableMapper(IHost host, ModelLoadContext ctx) ColumnType type; object value; _host.CheckDecode(saver.TryLoadTypeAndValue(ctx.Reader.BaseStream, out type, out value)); - _host.CheckDecode(type.IsVector); + _type = type as VectorType; + _host.CheckDecode(_type != null); _host.CheckDecode(value != null); - _type = type.AsVector; _getter = Utils.MarshalInvoke(DecodeInit, _type.ItemType.RawType, value); _metadataKind = ctx.Header.ModelVerReadable >= VersionAddedMetadataKind ? ctx.LoadNonEmptyString() : MetadataUtils.Kinds.SlotNames; @@ -493,7 +493,7 @@ public static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBoundM trainSchema.Label.Index, ref value); }; - return LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)mapper, type.AsVector, getter, MetadataUtils.Kinds.SlotNames, CanWrap); + return LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)mapper, type as VectorType, getter, MetadataUtils.Kinds.SlotNames, CanWrap); } public MultiClassClassifierScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema) diff --git a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs index 73d37120c8..0d94802771 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs @@ -496,7 +496,7 @@ private BoundColumn MakeColumn(Schema inputSchema, int iinfo) hasSlotNames = false; } - return new BoundColumn(InputSchema, _parent._columns[iinfo], sources, new VectorType(itemType.AsPrimitive, totalSize), + return new BoundColumn(InputSchema, _parent._columns[iinfo], sources, new VectorType((PrimitiveType)itemType, totalSize), isNormalized, hasSlotNames, hasCategoricals, totalSize, catCount); } diff --git a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs index 65e5c52dc4..8e0f146b7c 100644 --- a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs @@ -518,7 +518,7 @@ private void ComputeType(Schema input, int iinfo, SlotDropper slotDropper, Host.Assert(typeSrc.IsKnownSizeVector); var dstLength = slotDropper.DstLength; var hasSlotNames = input.HasSlotNames(_cols[iinfo], _srcTypes[iinfo].VectorSize); - type = new VectorType(typeSrc.ItemType.AsPrimitive, Math.Max(dstLength, 1)); + type = new VectorType((PrimitiveType)typeSrc.ItemType, Math.Max(dstLength, 1)); suppressed = dstLength == 0; } } @@ -822,7 +822,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() { var dstLength = _slotDropper[iinfo].DstLength; var hasSlotNames = InputSchema.HasSlotNames(_cols[iinfo], _srcTypes[iinfo].VectorSize); - var type = new VectorType(_srcTypes[iinfo].ItemType.AsPrimitive, Math.Max(dstLength, 1)); + var type = new VectorType((PrimitiveType)_srcTypes[iinfo].ItemType, Math.Max(dstLength, 1)); if (hasSlotNames && dstLength > 0) { diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs index 9ea54f4897..8c40a0138e 100644 --- a/src/Microsoft.ML.Data/Transforms/Hashing.cs +++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs @@ -911,7 +911,7 @@ private void AddMetaKeyValues(int i, MetadataBuilder builder) { _parent._keyValues[i].CopyTo(ref dst); }; - builder.AddKeyValues(_parent._kvTypes[i].VectorSize, _parent._kvTypes[i].ItemType.AsPrimitive, getter); + builder.AddKeyValues(_parent._kvTypes[i].VectorSize, (PrimitiveType)_parent._kvTypes[i].ItemType, getter); } protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) => _parent.GetGetterCore(input, iinfo, out disposer); diff --git a/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs b/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs index 8d041dc6ce..8b683a73bf 100644 --- a/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs +++ b/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs @@ -41,12 +41,12 @@ public static ValueMapper GetSimpleMapper(Schema schema, in var conv = Conversion.Conversions.Instance; // First: if not key, then get the standard string converison. - if (!type.IsKey) + if (!(type is KeyType keyType)) return conv.GetStringConversion(type); bool identity; // Second choice: if key, utilize the KeyValues metadata for that key, if it has one and is text. - if (schema.HasKeyValues(col, type.KeyCount)) + if (schema.HasKeyValues(col, keyType.KeyCount)) { // REVIEW: Non-textual KeyValues are certainly possible. Should we handle them? // Get the key names. @@ -70,7 +70,7 @@ public static ValueMapper GetSimpleMapper(Schema schema, in } // Third choice: just use the key value itself, subject to offsetting by the min. - return conv.GetKeyStringConversion(type.AsKey); + return conv.GetKeyStringConversion(keyType); } public static ValueMapper, StringBuilder> GetPairMapper(ValueMapper submap) diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs index f2b4277059..b5de375360 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs @@ -232,10 +232,10 @@ private void ComputeKvMaps(Schema schema, out ColumnType[] types, out KeyToValue var typeVals = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, ColMapNewToOld[iinfo]); Host.Check(typeVals != null, "Metadata KeyValues does not exist"); Host.Check(typeVals.VectorSize == typeSrc.ItemType.KeyCount, "KeyValues metadata size does not match column type key count"); - if (!typeSrc.IsVector) + if (!(typeSrc is VectorType vectorType)) types[iinfo] = typeVals.ItemType; else - types[iinfo] = new VectorType(typeVals.ItemType.AsPrimitive, typeSrc.AsVector); + types[iinfo] = new VectorType((PrimitiveType)typeVals.ItemType, vectorType); // MarshalInvoke with two generic params. Func func = GetKeyMetadata; @@ -258,7 +258,7 @@ private KeyToValueMap GetKeyMetadata(int iinfo, ColumnType typeKey Host.Check(keyMetadata.Length == typeKey.ItemType.KeyCount); VBufferUtils.Densify(ref keyMetadata); - return new KeyToValueMap(this, typeKey.ItemType.AsKey, typeVal.ItemType.AsPrimitive, keyMetadata, iinfo); + return new KeyToValueMap(this, (KeyType)typeKey.ItemType, (PrimitiveType)typeVal.ItemType, keyMetadata, iinfo); } /// /// A map is an object capable of creating the association from an input type, to an output diff --git a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs index 1c781b3e32..493d313cc2 100644 --- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs +++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs @@ -370,7 +370,7 @@ internal static bool GetNewType(IExceptionContext ectx, ColumnType srcType, Data if (!srcType.ItemType.IsKey && !srcType.ItemType.IsText) return false; } - else if (!srcType.ItemType.IsKey) + else if (!(srcType.ItemType is KeyType key)) itemType = PrimitiveType.FromKind(kind); else if (!KeyType.IsValidDataKind(kind)) { @@ -379,7 +379,6 @@ internal static bool GetNewType(IExceptionContext ectx, ColumnType srcType, Data } else { - var key = srcType.ItemType.AsKey; ectx.Assert(KeyType.IsValidDataKind(key.RawKind)); int count = key.Count; // Technically, it's an error for the counts not to match, but we'll let the Conversions @@ -436,8 +435,8 @@ private static bool CanConvertToType(IExceptionContext ectx, ColumnType srcType, return false; typeDst = itemType; - if (srcType.IsVector) - typeDst = new VectorType(itemType, srcType.AsVector); + if (srcType is VectorType vectorType) + typeDst = new VectorType(itemType, vectorType); return true; } @@ -473,9 +472,9 @@ protected override Delegate MakeGetter(Row input, int iinfo, Func act Contracts.AssertValue(input); Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); disposer = null; - if (!_types[iinfo].IsVector) + if (!(_types[iinfo] is VectorType vectorType)) return RowCursorUtils.GetGetterAs(_types[iinfo], input, _srcCols[iinfo]); - return RowCursorUtils.GetVecGetterAs(_types[iinfo].AsVector.ItemType, input, _srcCols[iinfo]); + return RowCursorUtils.GetVecGetterAs(vectorType.ItemType, input, _srcCols[iinfo]); } public void SaveAsOnnx(OnnxContext ctx) @@ -507,7 +506,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, node.AddAttribute("to", (byte)_parent._columns[iinfo].OutputKind); if (_parent._columns[iinfo].OutputKeyRange != null) { - var key = _types[iinfo].ItemType.AsKey; + var key = (KeyType)_types[iinfo].ItemType; node.AddAttribute("min", key.Min); node.AddAttribute("max", key.Count); node.AddAttribute("contiguous", key.Contiguous); diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs index cda1d03401..b179f358dc 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs @@ -736,8 +736,8 @@ public Mapper(ValueToKeyMappingTransformer parent, Schema inputSchema) var type = _infos[i].TypeSrc; KeyType keyType = _parent._unboundMaps[i].OutputType; ColumnType colType; - if (type.IsVector) - colType = new VectorType(keyType, type.AsVector); + if (type is VectorType vectorType) + colType = new VectorType(keyType, vectorType); else colType = keyType; _types[i] = colType; diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs index 87c112b87d..8c7961c969 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs @@ -49,7 +49,7 @@ public static Builder Create(ColumnType type, SortOrder sortOrder) Contracts.Assert(sortOrder == SortOrder.Occurrence || sortOrder == SortOrder.Value); bool sorted = sortOrder == SortOrder.Value; - PrimitiveType itemType = type.ItemType.AsPrimitive; + PrimitiveType itemType = type.ItemType as PrimitiveType; Contracts.AssertValue(itemType); if (itemType.IsText) return new TextImpl(sorted); @@ -568,7 +568,7 @@ private static TermMap LoadCodecCore(ModelLoadContext ctx, IExceptionContext } } - return new HashArrayImpl(codec.Type.AsPrimitive, values); + return new HashArrayImpl((PrimitiveType)codec.Type, values); } internal abstract void WriteTextTerms(TextWriter writer); @@ -1101,7 +1101,7 @@ private bool AddMetadataCore(ColumnType srcMetaType, MetadataBuilder buil _host.AssertValue(srcMetaType); _host.Assert(srcMetaType.RawType == typeof(TMeta)); _host.AssertValue(builder); - var srcType = TypedMap.ItemType.AsKey; + var srcType = TypedMap.ItemType as KeyType; _host.AssertValue(srcType); var dstType = new KeyType(DataKind.U4, srcType.Min, srcType.Count); var convInst = Runtime.Data.Conversion.Conversions.Instance; @@ -1156,7 +1156,7 @@ private bool AddMetadataCore(ColumnType srcMetaType, MetadataBuilder buil getter(ref dst); _host.Assert(dst.Length == TypedMap.OutputType.KeyCount); }; - builder.AddKeyValues(TypedMap.OutputType.KeyCount, srcMetaType.ItemType.AsPrimitive, mgetter); + builder.AddKeyValues(TypedMap.OutputType.KeyCount, (PrimitiveType)srcMetaType.ItemType, mgetter); } return true; } @@ -1169,7 +1169,7 @@ public override void WriteTextTerms(TextWriter writer) _schema.TryGetColumnIndex(_infos[_iinfo].Source, out int srcCol); ColumnType srcMetaType = _schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, srcCol); if (srcMetaType == null || srcMetaType.VectorSize != TypedMap.ItemType.KeyCount || - TypedMap.ItemType.KeyCount == 0 || !Utils.MarshalInvoke(WriteTextTermsCore, srcMetaType.ItemType.RawType, srcMetaType.AsVector.ItemType, writer)) + TypedMap.ItemType.KeyCount == 0 || !Utils.MarshalInvoke(WriteTextTermsCore, srcMetaType.ItemType.RawType, ((VectorType)srcMetaType).ItemType, writer)) { // No valid input key-value metadata. Back off to the base implementation. base.WriteTextTerms(writer); @@ -1180,7 +1180,7 @@ private bool WriteTextTermsCore(PrimitiveType srcMetaType, TextWriter wri { _host.AssertValue(srcMetaType); _host.Assert(srcMetaType.RawType == typeof(TMeta)); - var srcType = TypedMap.ItemType.AsKey; + var srcType = TypedMap.ItemType as KeyType; _host.AssertValue(srcType); var dstType = new KeyType(DataKind.U4, srcType.Min, srcType.Count); var convInst = Runtime.Data.Conversion.Conversions.Instance; diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs index dafa19eb61..d393e6c316 100644 --- a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs @@ -603,7 +603,7 @@ protected static int CheckLabelColumn(IHostEnvironment env, IPredictorModel[] mo if (mdType == null || !mdType.IsKnownSizeVector) throw env.Except("Label column of type key must have a vector of key values metadata"); - return Utils.MarshalInvoke(CheckKeyLabelColumnCore, mdType.ItemType.RawType, env, models, labelType.AsKey, schema, labelInfo.Index, mdType); + return Utils.MarshalInvoke(CheckKeyLabelColumnCore, mdType.ItemType.RawType, env, models, (KeyType)labelType, schema, labelInfo.Index, mdType); } // When the label column is not a key, we check that the number of classes is the same for all the predictors, by checking the @@ -655,8 +655,8 @@ private static int CheckKeyLabelColumnCore(IHostEnvironment env, IPredictorMo if (labelInfo == null) throw env.Except("Training schema for model {0} does not have a label column", i); - var curLabelType = rmd.Schema.Schema.GetColumnType(rmd.Schema.Label.Index); - if (!labelType.Equals(curLabelType.AsKey)) + var curLabelType = rmd.Schema.Schema.GetColumnType(rmd.Schema.Label.Index) as KeyType; + if (!labelType.Equals(curLabelType)) throw env.Except("Label column of model {0} has different type than model 0", i); var mdType = rmd.Schema.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, labelInfo.Index); diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index 7eef450580..ec3827286c 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -788,7 +788,7 @@ private static IDataView AppendLabelTransform(IHostEnvironment env, IChannel ch, "labelPermutationSeed != 0 only applies on a multi-class learning problem when the label type is a key."); return input; } - return Utils.MarshalInvoke(AppendFloatMapper, labelType.RawType, env, ch, input, labelName, labelType.AsKey, + return Utils.MarshalInvoke(AppendFloatMapper, labelType.RawType, env, ch, input, labelName, (KeyType)labelType, labelPermutationSeed); } } diff --git a/src/Microsoft.ML.HalLearners/VectorWhitening.cs b/src/Microsoft.ML.HalLearners/VectorWhitening.cs index 2a0358f74a..707e8c780f 100644 --- a/src/Microsoft.ML.HalLearners/VectorWhitening.cs +++ b/src/Microsoft.ML.HalLearners/VectorWhitening.cs @@ -323,7 +323,8 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol // Check if the input column's type is supported. Note that only float vector with a known shape is allowed. internal static string TestColumn(ColumnType type) { - if ((type.IsVector && !type.IsKnownSizeVector && (type.AsVector.Dimensions.Length > 1)) || type.ItemType != NumberType.R4) + if ((type is VectorType vectorType && !vectorType.IsKnownSizeVector && vectorType.Dimensions.Length > 1) + || type.ItemType != NumberType.R4) return "Expected float or float vector of known size"; if ((long)type.ValueCount * type.ValueCount > Utils.ArrayMaxSize) diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index 46300386ab..2216beadb7 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -139,14 +139,14 @@ protected override void ConvertNaNLabels(IChannel ch, RoleMappedData data, float if (maxLabel >= _maxNumClass) throw ch.ExceptParam(nameof(data), $"max labelColumn cannot exceed {_maxNumClass}"); - if (data.Schema.Label.Type.IsKey) + if (data.Schema.Label.Type is KeyType keyType) { - ch.Check(data.Schema.Label.Type.AsKey.Contiguous, "labelColumn value should be contiguous"); + ch.Check(keyType.Contiguous, "labelColumn value should be contiguous"); if (hasNaNLabel) - _numClass = data.Schema.Label.Type.AsKey.Count + 1; + _numClass = keyType.Count + 1; else - _numClass = data.Schema.Label.Type.AsKey.Count; - _tlcNumClass = data.Schema.Label.Type.AsKey.Count; + _numClass = keyType.Count; + _tlcNumClass = keyType.Count; } else { diff --git a/src/Microsoft.ML.Onnx/OnnxUtils.cs b/src/Microsoft.ML.Onnx/OnnxUtils.cs index 1d6c4b6fc1..877fc7b9ed 100644 --- a/src/Microsoft.ML.Onnx/OnnxUtils.cs +++ b/src/Microsoft.ML.Onnx/OnnxUtils.cs @@ -296,10 +296,10 @@ public static ModelArgs GetModelArgs(ColumnType type, string colName, TensorProto.Types.DataType dataType = TensorProto.Types.DataType.Undefined; DataKind rawKind; - if (type.IsVector) - rawKind = type.AsVector.ItemType.RawKind; - else if (type.IsKey) - rawKind = type.AsKey.RawKind; + if (type is VectorType vectorType) + rawKind = vectorType.ItemType.RawKind; + else if (type is KeyType keyType) + rawKind = keyType.RawKind; else rawKind = type.RawKind; @@ -367,7 +367,7 @@ public static ModelArgs GetModelArgs(ColumnType type, string colName, dimsLocal.Add(1); else if (type.ValueCount > 1) { - var vec = type.AsVector; + var vec = (VectorType)type; for (int i = 0; i < vec.Dimensions.Length; i++) dimsLocal.Add(vec.Dimensions[i]); } diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs index e42c6b002e..d858d5bef4 100644 --- a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs +++ b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs @@ -347,7 +347,7 @@ private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data, using (var buffer = PrepareBuffer()) { buffer.Train(ch, rowCount, colCount, cursor, labGetter, matrixRowIndexGetter, matrixColumnIndexGetter); - predictor = new MatrixFactorizationPredictor(Host, buffer, matrixColumnIndexColInfo.Type.AsKey, matrixRowIndexColInfo.Type.AsKey); + predictor = new MatrixFactorizationPredictor(Host, buffer, (KeyType)matrixColumnIndexColInfo.Type, (KeyType)matrixRowIndexColInfo.Type); } } else @@ -365,7 +365,7 @@ private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data, buffer.TrainWithValidation(ch, rowCount, colCount, cursor, labGetter, matrixRowIndexGetter, matrixColumnIndexGetter, validCursor, validLabelGetter, validMatrixRowIndexGetter, validMatrixColumnIndexGetter); - predictor = new MatrixFactorizationPredictor(Host, buffer, matrixColumnIndexColInfo.Type.AsKey, matrixRowIndexColInfo.Type.AsKey); + predictor = new MatrixFactorizationPredictor(Host, buffer, (KeyType)matrixColumnIndexColInfo.Type, (KeyType)matrixRowIndexColInfo.Type); } } } diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs index be73a2de30..821cb21cb0 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs @@ -299,14 +299,16 @@ private FieldAwareFactorizationMachinePredictor TrainCore(IChannel ch, IProgress for (int f = 0; f < fieldCount; f++) { var col = featureColumns[f]; - Host.Assert(col.Type.AsVector.VectorSize > 0); if (col == null) throw ch.ExceptParam(nameof(data), "Empty feature column not allowed"); Host.Assert(!data.Schema.Schema.IsHidden(col.Index)); - if (!col.Type.IsKnownSizeVector || col.Type.ItemType != NumberType.Float) + if (!(col.Type is VectorType vectorType) || + !vectorType.IsKnownSizeVector || + vectorType.ItemType != NumberType.Float) throw ch.ExceptParam(nameof(data), "Training feature column '{0}' must be a known-size vector of R4, but has type: {1}.", col.Name, col.Type); + Host.Assert(vectorType.VectorSize > 0); fieldColumnIndexes[f] = col.Index; - totalFeatureCount += col.Type.AsVector.VectorSize; + totalFeatureCount += vectorType.VectorSize; } ch.Check(checked(totalFeatureCount * fieldCount * _latentDimAligned) <= Utils.ArrayMaxSize, "Latent dimension or the number of fields too large"); if (predictor != null) diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index b7a1ba901f..f555545099 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -837,7 +837,7 @@ public Mapper(TensorFlowTransform parent, Schema inputSchema) : else { if (shape.Select((dim, j) => dim != -1 && dim != colTypeDims[j]).Any(b => b)) - throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is {type.AsVector.ToString()}."); + throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is {vecType.ToString()}."); // Fill in the unknown dimensions. var l = new long[originalShape.NumDimensions]; diff --git a/src/Microsoft.ML.Transforms/GroupTransform.cs b/src/Microsoft.ML.Transforms/GroupTransform.cs index 1f77e1dd16..1b9e341007 100644 --- a/src/Microsoft.ML.Transforms/GroupTransform.cs +++ b/src/Microsoft.ML.Transforms/GroupTransform.cs @@ -287,8 +287,9 @@ private static ColumnType[] BuildColumnTypes(Schema input, int[] ids) for (int i = 0; i < ids.Length; i++) { var srcType = input.GetColumnType(ids[i]); - Contracts.Assert(srcType.IsPrimitive); - types[i] = new VectorType(srcType.AsPrimitive, size: 0); + var primitiveType = srcType as PrimitiveType; + Contracts.Assert(primitiveType != null); + types[i] = new VectorType(primitiveType, size: 0); } return types; } diff --git a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs index 77e89f8882..cc678facfd 100644 --- a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs @@ -164,7 +164,7 @@ public Mapper(MissingValueDroppingTransformer parent, Schema inputSchema) : inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _srcCols[i]); var srcCol = inputSchema[_srcCols[i]]; _srcTypes[i] = srcCol.Type; - _types[i] = new VectorType(srcCol.Type.ItemType.AsPrimitive); + _types[i] = new VectorType((PrimitiveType)srcCol.Type.ItemType); _isNAs[i] = GetIsNADelegate(srcCol.Type); } } diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs index a58603c16e..09ab2853da 100644 --- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs +++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs @@ -140,11 +140,11 @@ private VectorType[] GetTypesAndMetadata() // This ensures that our feature count doesn't overflow. Host.Check(type.ValueCount < int.MaxValue / 2); - if (!type.IsVector) + if (!(type is VectorType vectorType)) types[iinfo] = new VectorType(NumberType.Float, 2); else { - types[iinfo] = new VectorType(NumberType.Float, type.AsVector, 2); + types[iinfo] = new VectorType(NumberType.Float, vectorType, 2); // Produce slot names metadata iff the source has (valid) slot names. ColumnType typeNames; diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs index 73efd1d719..8e2580accb 100644 --- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs @@ -182,10 +182,10 @@ private ColInfo[] CreateInfos(Schema inputSchema) _parent.CheckInputColumn(inputSchema, i, colSrc); var inType = inputSchema.GetColumnType(colSrc); ColumnType outType; - if (!inType.IsVector) + if (!(inType is VectorType vectorType)) outType = BoolType.Instance; else - outType = new VectorType(BoolType.Instance, inType.AsVector); + outType = new VectorType(BoolType.Instance, vectorType); infos[i] = new ColInfo(_parent.ColumnPairs[i].input, _parent.ColumnPairs[i].output, inType, outType); } return infos; @@ -469,7 +469,9 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta)) metadata.Add(slotMeta); metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)); - ColumnType type = !col.ItemType.IsVector ? (ColumnType)BoolType.Instance : new VectorType(BoolType.Instance, col.ItemType.AsVector); + ColumnType type = !(col.ItemType is VectorType vectorType) ? + (ColumnType)BoolType.Instance : + new VectorType(BoolType.Instance, vectorType); result[colPair.output] = new SchemaShape.Column(colPair.output, col.Kind, type, false, new SchemaShape(metadata.ToArray())); } return new SchemaShape(result.Values); diff --git a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs index 296aed0eab..a6579cb9b2 100644 --- a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs +++ b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs @@ -324,8 +324,8 @@ private void GetReplacementValues(IDataView input, ColumnInfo[] columns, out obj input.Schema.TryGetColumnIndex(columns[iinfo].Input, out int colSrc); sources[iinfo] = colSrc; var type = input.Schema.GetColumnType(colSrc); - if (type.IsVector) - type = new VectorType(type.ItemType.AsPrimitive, type.AsVector); + if (type is VectorType vectorType) + type = new VectorType((PrimitiveType)type.ItemType, vectorType); Delegate isNa = GetIsNADelegate(type); types[iinfo] = type; var kind = (ReplacementKind)columns[iinfo].Replacement; @@ -593,8 +593,8 @@ public Mapper(MissingValueReplacingTransformer parent, Schema inputSchema) for (int i = 0; i < _parent.ColumnPairs.Length; i++) { var type = _infos[i].TypeSrc; - if (type.IsVector) - type = new VectorType(type.ItemType.AsPrimitive, type.AsVector); + if (type is VectorType vectorType) + type = new VectorType((PrimitiveType)type.ItemType, vectorType); var repType = _parent._repIsDefault[i] != null ? _parent._replaceTypes[i] : _parent._replaceTypes[i].ItemType; if (!type.ItemType.Equals(repType.ItemType)) throw Host.ExceptParam(nameof(InputSchema), "Column '{0}' item type '{1}' does not match expected ColumnType of '{2}'", @@ -897,10 +897,10 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src { DataKind rawKind; var type = _infos[iinfo].TypeSrc; - if (type.IsVector) - rawKind = type.AsVector.ItemType.RawKind; - else if (type.IsKey) - rawKind = type.AsKey.RawKind; + if (type is VectorType vectorType) + rawKind = vectorType.ItemType.RawKind; + else if (type is KeyType keyType) + rawKind = keyType.RawKind; else rawKind = type.RawKind; @@ -965,7 +965,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) metadata.Add(slotMeta); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.IsNormalized, out var normalized)) metadata.Add(normalized); - var type = !col.ItemType.IsVector ? col.ItemType : new VectorType(col.ItemType.ItemType.AsPrimitive, col.ItemType.AsVector); + var type = !(col.ItemType is VectorType vectorType) ? + col.ItemType : + new VectorType((PrimitiveType)col.ItemType.ItemType, vectorType); result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, type, false, new SchemaShape(metadata.ToArray())); } return new SchemaShape(result.Values); diff --git a/src/Microsoft.ML.Transforms/UngroupTransform.cs b/src/Microsoft.ML.Transforms/UngroupTransform.cs index e52f6d5354..5354dfda0b 100644 --- a/src/Microsoft.ML.Transforms/UngroupTransform.cs +++ b/src/Microsoft.ML.Transforms/UngroupTransform.cs @@ -291,7 +291,7 @@ private static void CheckAndBind(IExceptionContext ectx, Schema inputSchema, if (!colType.IsVector || !colType.ItemType.IsPrimitive) throw ectx.ExceptUserArg(nameof(Arguments.Column), "Pivot column '{0}' has type '{1}', but must be a vector of primitive types", name, colType); - infos[i] = new PivotColumnInfo(name, col, colType.VectorSize, colType.ItemType.AsPrimitive); + infos[i] = new PivotColumnInfo(name, col, colType.VectorSize, (PrimitiveType)colType.ItemType); } }