diff --git a/src/Microsoft.ML.Core/Data/ColumnType.cs b/src/Microsoft.ML.Core/Data/ColumnType.cs
index 83b5caa27d..628d955907 100644
--- a/src/Microsoft.ML.Core/Data/ColumnType.cs
+++ b/src/Microsoft.ML.Core/Data/ColumnType.cs
@@ -80,12 +80,6 @@ private protected ColumnType(Type rawType, DataKind rawKind)
[BestFriend]
internal bool IsPrimitive { get; }
- ///
- /// Equivalent to as .
- ///
- [BestFriend]
- internal PrimitiveType AsPrimitive => IsPrimitive ? (PrimitiveType)this : null;
-
///
/// Whether this type is a standard numeric type. External code should use is .
///
@@ -140,12 +134,6 @@ internal bool IsBool
[BestFriend]
internal bool IsKey { get; }
- ///
- /// Equivalent to as .
- ///
- [BestFriend]
- internal KeyType AsKey => IsKey ? (KeyType)this : null;
-
///
/// Zero return means either it's not a key type or the cardinality is unknown. External code should first
/// test whether this is of type , then if so get the property
@@ -166,12 +154,6 @@ internal bool IsBool
[BestFriend]
internal bool IsVector { get; }
- ///
- /// Equivalent to as .
- ///
- [BestFriend]
- internal VectorType AsVector => IsVector ? (VectorType)this : null;
-
///
/// For non-vector types, this returns the column type itself (i.e., return this).
///
diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs
index 024953f173..723c38fc74 100644
--- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs
+++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs
@@ -165,12 +165,12 @@ public static VectorType GetCategoricalType(int rangeCount)
return new VectorType(NumberType.I4, rangeCount, 2);
}
- private static volatile ColumnType _scoreColumnSetIdType;
+ private static volatile KeyType _scoreColumnSetIdType;
///
/// The type of the ScoreColumnSetId metadata.
///
- public static ColumnType ScoreColumnSetIdType
+ public static KeyType ScoreColumnSetIdType
{
get
{
diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
index 868dd83b19..a6541c8b8b 100644
--- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
@@ -210,7 +210,7 @@ private void RunCore(IChannel ch)
private bool ShouldAddColumn(Schema schema, int i, uint scoreSet, bool outputNamesAndLabels)
{
uint scoreSetId = 0;
- if (schema.TryGetMetadata(MetadataUtils.ScoreColumnSetIdType.AsPrimitive, MetadataUtils.Kinds.ScoreColumnSetId, i, ref scoreSetId)
+ if (schema.TryGetMetadata(MetadataUtils.ScoreColumnSetIdType, MetadataUtils.Kinds.ScoreColumnSetId, i, ref scoreSetId)
&& scoreSetId == scoreSet)
{
return true;
diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs
index 1cfe187f24..e80135d6c8 100644
--- a/src/Microsoft.ML.Data/Data/Conversion.cs
+++ b/src/Microsoft.ML.Data/Data/Conversion.cs
@@ -412,15 +412,12 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst,
conv = null;
identity = false;
- if (typeSrc.IsKey)
+ if (typeSrc is KeyType keySrc)
{
- var keySrc = typeSrc.AsKey;
-
// Key types are only convertable to compatible key types or unsigned integer
// types that are large enough.
- if (typeDst.IsKey)
+ if (typeDst is KeyType keyDst)
{
- var keyDst = typeDst.AsKey;
// We allow the Min value to shift. We currently don't allow the counts to vary.
// REVIEW: Should we allow the counts to vary? Allowing the dst to be bigger is trivial.
// Smaller dst means mapping values to NA.
@@ -451,11 +448,11 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst,
// REVIEW: Should we look for illegal values and force them to zero? If so, then
// we'll need to set identity to false.
}
- else if (typeDst.IsKey)
+ else if (typeDst is KeyType keyDst)
{
if (!typeSrc.IsText)
return false;
- conv = GetKeyParse(typeDst.AsKey);
+ conv = GetKeyParse(keyDst);
return true;
}
else if (!typeDst.IsStandardScalar)
@@ -490,10 +487,10 @@ public bool TryGetStringConversion(ColumnType type, out ValueMapper(type.AsKey);
+ conv = GetKeyStringConversion(keyType);
return true;
}
return TryGetStringConversion(out conv);
@@ -572,8 +569,8 @@ public TryParseMapper GetTryParseConversion(ColumnType typeDst)
"Parse conversion only supported for standard types");
Contracts.Check(typeDst.RawType == typeof(TDst), "Wrong TDst type parameter");
- if (typeDst.IsKey)
- return GetKeyTryParse(typeDst.AsKey);
+ if (typeDst is KeyType keyType)
+ return GetKeyTryParse(keyType);
Contracts.Assert(_tryParseDelegates.ContainsKey(typeDst.RawKind));
return (TryParseMapper)_tryParseDelegates[typeDst.RawKind];
diff --git a/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs b/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs
index c67c6ae733..604fe3c2b7 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs
@@ -55,11 +55,11 @@ private static ColumnType MakeColumnType(SchemaShape.Column inputCol)
{
ColumnType curType = inputCol.ItemType;
if (inputCol.IsKey)
- curType = new KeyType(curType.AsPrimitive.RawKind, 0, AllKeySizes);
+ curType = new KeyType(((PrimitiveType)curType).RawKind, 0, AllKeySizes);
if (inputCol.Kind == SchemaShape.Column.VectorKind.VariableVector)
- curType = new VectorType(curType.AsPrimitive, 0);
+ curType = new VectorType((PrimitiveType)curType, 0);
else if (inputCol.Kind == SchemaShape.Column.VectorKind.Vector)
- curType = new VectorType(curType.AsPrimitive, AllVectorSizes);
+ curType = new VectorType((PrimitiveType)curType, AllVectorSizes);
return curType;
}
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
index a3d5260531..7e8a90b75e 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
@@ -840,9 +840,8 @@ public void Save(ModelSaveContext ctx)
Contracts.Assert((DataKind)(byte)type.RawKind == type.RawKind);
ctx.Writer.Write((byte)type.RawKind);
ctx.Writer.WriteBoolByte(type.IsKey);
- if (type.IsKey)
+ if (type is KeyType key)
{
- var key = type.AsKey;
ctx.Writer.WriteBoolByte(key.Contiguous);
ctx.Writer.Write(key.Min);
ctx.Writer.Write(key.Count);
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs
index cdcb507f51..f9d6cd2b09 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs
@@ -663,12 +663,14 @@ public Parser(TextLoader parent)
{
var info = _infos[i];
- if (info.ColType.ItemType.IsKey)
+ if (info.ColType is KeyType keyType)
{
- if (!info.ColType.IsVector)
- _creator[i] = cache.GetCreatorOne(info.ColType.AsKey);
- else
- _creator[i] = cache.GetCreatorVec(info.ColType.ItemType.AsKey);
+ _creator[i] = cache.GetCreatorOne(keyType);
+ continue;
+ }
+ else if (info.ColType is VectorType vectorType && vectorType.ItemType is KeyType vectorKeyType)
+ {
+ _creator[i] = cache.GetCreatorVec(vectorKeyType);
continue;
}
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs
index 03fd7dfa70..08bc2c3e3f 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs
@@ -489,9 +489,8 @@ private TextLoader.Column GetColumn(string name, ColumnType type, int? start)
{
DataKind? kind;
KeyRange keyRange = null;
- if (type.ItemType.IsKey)
+ if (type.ItemType is KeyType key)
{
- var key = type.ItemType.AsKey;
if (!key.Contiguous)
keyRange = new KeyRange(key.Min, contiguous: false);
else if (key.Count == 0)
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs
index fe63426c61..83617f0f52 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs
@@ -659,7 +659,7 @@ public VectorType GetSlotType(int col)
var view = _parent._entries[col].GetViewOrNull();
if (view == null)
return null;
- return view.Schema.GetColumnType(0).AsVector;
+ return view.Schema.GetColumnType(0) as VectorType;
}
}
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeSaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeSaver.cs
index 9ce4ab3872..6878fbf935 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeSaver.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeSaver.cs
@@ -70,7 +70,7 @@ public bool IsColumnSavable(ColumnType type)
// an artificial vector type out of this. Obviously if you can't make a vector
// out of the items, then you could not save each slot's values.
var itemType = type.ItemType;
- var primitiveType = itemType.AsPrimitive;
+ var primitiveType = itemType as PrimitiveType;
if (primitiveType == null)
return false;
var vectorType = new VectorType(primitiveType, size: 2);
diff --git a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs
index 756876dd97..62d28e1809 100644
--- a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs
+++ b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs
@@ -208,12 +208,12 @@ private Delegate CreateGetter(ColumnType colType, InternalSchemaDefinition.Colum
else
Host.Assert(colType.RawType == outputType);
- if (!colType.IsKey)
+ if (!(colType is KeyType keyType))
del = CreateDirectGetterDelegate;
else
{
var keyRawType = colType.RawType;
- Host.Assert(colType.AsKey.Contiguous);
+ Host.Assert(keyType.Contiguous);
Func delForKey = CreateKeyGetterDelegate;
return Utils.MarshalInvoke(delForKey, keyRawType, peek, colType);
}
@@ -299,9 +299,10 @@ private Delegate CreateDirectGetterDelegate(Delegate peekDel)
private Delegate CreateKeyGetterDelegate(Delegate peekDel, ColumnType colType)
{
// Make sure the function is dealing with key.
- Host.Check(colType.IsKey);
+ KeyType keyType = colType as KeyType;
+ Host.Check(keyType != null);
// Following equations work only with contiguous key type.
- Host.Check(colType.AsKey.Contiguous);
+ Host.Check(keyType.Contiguous);
// Following equations work only with unsigned integers.
Host.Check(typeof(TDst) == typeof(ulong) || typeof(TDst) == typeof(uint) ||
typeof(TDst) == typeof(byte) || typeof(TDst) == typeof(bool));
@@ -312,8 +313,8 @@ private Delegate CreateKeyGetterDelegate(Delegate peekDel, ColumnType colT
TDst rawKeyValue = default;
ulong key = 0; // the raw key value as ulong
- ulong min = colType.AsKey.Min;
- ulong max = min + (ulong)colType.AsKey.Count - 1;
+ ulong min = keyType.Min;
+ ulong max = min + (ulong)keyType.Count - 1;
ulong result = 0; // the result as ulong
ValueGetter getter = (ref TDst dst) =>
{
diff --git a/src/Microsoft.ML.Data/DataView/Transposer.cs b/src/Microsoft.ML.Data/DataView/Transposer.cs
index 901c76275a..cbd5dff214 100644
--- a/src/Microsoft.ML.Data/DataView/Transposer.cs
+++ b/src/Microsoft.ML.Data/DataView/Transposer.cs
@@ -306,8 +306,9 @@ public SchemaImpl(Transposer parent)
{
ColumnInfo srcInfo = _parent._cols[c];
var ctype = srcInfo.Type.ItemType;
- _ectx.Assert(ctype.IsPrimitive);
- _slotTypes[c] = new VectorType(ctype.AsPrimitive, _parent.RowCount);
+ var primitiveType = ctype as PrimitiveType;
+ _ectx.Assert(primitiveType != null);
+ _slotTypes[c] = new VectorType(primitiveType, _parent.RowCount);
}
AsSchema = Schema.Create(this);
@@ -1189,7 +1190,7 @@ private sealed class ColumnSplitter : Splitter
public ColumnSplitter(IDataView view, int col, int[] lims)
: base(view, col)
{
- var type = _view.Schema.GetColumnType(SrcCol).AsVector;
+ var type = _view.Schema.GetColumnType(SrcCol) as VectorType;
// Only valid use is for two or more slices.
Contracts.Assert(Utils.Size(lims) >= 2);
Contracts.AssertValue(type);
diff --git a/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs b/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs
index b2d2eb8450..f68fcff00a 100644
--- a/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs
@@ -44,7 +44,7 @@ public static CommonOutputs.TransformOutput SelectColumns(IHostEnvironment env,
private static bool ShouldAddColumn(Schema schema, int i, string[] extraColumns, uint scoreSet)
{
uint scoreSetId = 0;
- if (schema.TryGetMetadata(MetadataUtils.ScoreColumnSetIdType.AsPrimitive, MetadataUtils.Kinds.ScoreColumnSetId, i, ref scoreSetId)
+ if (schema.TryGetMetadata(MetadataUtils.ScoreColumnSetIdType, MetadataUtils.Kinds.ScoreColumnSetId, i, ref scoreSetId)
&& scoreSetId == scoreSet)
{
return true;
diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs
index 0eb63d3f0c..22703accc6 100644
--- a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs
+++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs
@@ -723,7 +723,7 @@ public static void ReconcileVectorKeyValues(IHostEnvironment env, IDataView[] vi
(ref VBuffer> dst) => schema.GetMetadata(MetadataUtils.Kinds.SlotNames, index, ref dst);
}
views[i] = LambdaColumnMapper.Create(env, "ReconcileKeyValues", views[i], columnName, columnName,
- type, new VectorType(keyType, type.AsVector), mapper, keyValueGetter, slotNamesGetter);
+ type, new VectorType(keyType, type as VectorType), mapper, keyValueGetter, slotNamesGetter);
}
}
@@ -974,7 +974,7 @@ private static bool VerifyVectorColumnsMatch(int cachedSize, int col, IDataView
private static IDataView AddVarLengthColumn(IHostEnvironment env, IDataView idv, string variableSizeVectorColumnName, ColumnType typeSrc)
{
return LambdaColumnMapper.Create(env, "ChangeToVarLength", idv, variableSizeVectorColumnName,
- variableSizeVectorColumnName + "_VarLength", typeSrc, new VectorType(typeSrc.ItemType.AsPrimitive),
+ variableSizeVectorColumnName + "_VarLength", typeSrc, new VectorType((PrimitiveType)typeSrc.ItemType),
(in VBuffer src, ref VBuffer dst) => src.CopyTo(ref dst));
}
diff --git a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs
index 12aaa47ad0..916f8ddc72 100644
--- a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs
@@ -1001,11 +1001,11 @@ protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMa
if (!perInst.Schema.TryGetColumnIndex(schema.Label.Name, out int labelCol))
throw Host.Except("Could not find column '{0}'", schema.Label.Name);
var labelType = perInst.Schema.GetColumnType(labelCol);
- if (labelType.IsKey && (!perInst.Schema.HasKeyValues(labelCol, labelType.KeyCount) || labelType.RawKind != DataKind.U4))
+ if (labelType is KeyType keyType && (!perInst.Schema.HasKeyValues(labelCol, keyType.KeyCount) || labelType.RawKind != DataKind.U4))
{
perInst = LambdaColumnMapper.Create(Host, "ConvertToDouble", perInst, schema.Label.Name,
schema.Label.Name, perInst.Schema.GetColumnType(labelCol), NumberType.R8,
- (in uint src, ref double dst) => dst = src == 0 ? double.NaN : src - 1 + (double)labelType.AsKey.Min);
+ (in uint src, ref double dst) => dst = src == 0 ? double.NaN : src - 1 + (double)keyType.Min);
}
var perInstSchema = perInst.Schema;
diff --git a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs
index fa2de9c726..f33fcaee1a 100644
--- a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs
+++ b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs
@@ -549,7 +549,7 @@ private void CheckInputColumnTypes(Schema schema, out ColumnType labelType, out
var t = schema.GetColumnType(LabelIndex);
if (!t.IsKnownSizeVector || (t.ItemType != NumberType.R4 && t.ItemType != NumberType.R8))
throw Host.Except("Label column '{0}' has type '{1}' but must be a known-size vector of R4 or R8", LabelCol, t);
- labelType = new VectorType(t.ItemType.AsPrimitive, t.VectorSize);
+ labelType = new VectorType((PrimitiveType)t.ItemType, t.VectorSize);
var slotNamesType = new VectorType(TextType.Instance, t.VectorSize);
var builder = new MetadataBuilder();
builder.AddSlotNames(t.VectorSize, CreateSlotNamesGetter(schema, LabelIndex, labelType.VectorSize, "True"));
@@ -558,7 +558,7 @@ private void CheckInputColumnTypes(Schema schema, out ColumnType labelType, out
t = schema.GetColumnType(ScoreIndex);
if (t.VectorSize == 0 || t.ItemType != NumberType.Float)
throw Host.Except("Score column '{0}' has type '{1}' but must be a known length vector of type R4", ScoreCol, t);
- scoreType = new VectorType(t.ItemType.AsPrimitive, t.VectorSize);
+ scoreType = new VectorType((PrimitiveType)t.ItemType, t.VectorSize);
builder = new MetadataBuilder();
builder.AddSlotNames(t.VectorSize, CreateSlotNamesGetter(schema, ScoreIndex, scoreType.VectorSize, "Predicted"));
diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
index 84acb75898..4e25546ae5 100644
--- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
+++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
@@ -123,7 +123,7 @@ private static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBound
trainSchema.Label.Index, ref value);
};
- return MultiClassClassifierScorer.LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)mapper, type.AsVector, getter, MetadataUtils.Kinds.TrainingLabelValues, CanWrap);
+ return MultiClassClassifierScorer.LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)mapper, type as VectorType, getter, MetadataUtils.Kinds.TrainingLabelValues, CanWrap);
}
public BinaryClassifierScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
diff --git a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs
index 70638e35a9..28cc55f5e5 100644
--- a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs
+++ b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs
@@ -365,7 +365,7 @@ public RowMapper(IHostEnvironment env, BindableMapper parent, RoleMappedSchema s
else
{
_outputSchema = Schema.Create(new FeatureContributionSchema(_env, DefaultColumnNames.FeatureContributions,
- new VectorType(NumberType.R4, schema.Feature.Type.AsVector),
+ new VectorType(NumberType.R4, schema.Feature.Type as VectorType),
InputSchema, InputRoleMappedSchema.Feature.Index));
}
diff --git a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs
index 4b190636f9..d5b7ff9669 100644
--- a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs
+++ b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs
@@ -130,9 +130,9 @@ private LabelNameBindableMapper(IHost host, ModelLoadContext ctx)
ColumnType type;
object value;
_host.CheckDecode(saver.TryLoadTypeAndValue(ctx.Reader.BaseStream, out type, out value));
- _host.CheckDecode(type.IsVector);
+ _type = type as VectorType;
+ _host.CheckDecode(_type != null);
_host.CheckDecode(value != null);
- _type = type.AsVector;
_getter = Utils.MarshalInvoke(DecodeInit, _type.ItemType.RawType, value);
_metadataKind = ctx.Header.ModelVerReadable >= VersionAddedMetadataKind ?
ctx.LoadNonEmptyString() : MetadataUtils.Kinds.SlotNames;
@@ -493,7 +493,7 @@ public static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBoundM
trainSchema.Label.Index, ref value);
};
- return LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)mapper, type.AsVector, getter, MetadataUtils.Kinds.SlotNames, CanWrap);
+ return LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)mapper, type as VectorType, getter, MetadataUtils.Kinds.SlotNames, CanWrap);
}
public MultiClassClassifierScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
diff --git a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs
index 73d37120c8..0d94802771 100644
--- a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs
+++ b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs
@@ -496,7 +496,7 @@ private BoundColumn MakeColumn(Schema inputSchema, int iinfo)
hasSlotNames = false;
}
- return new BoundColumn(InputSchema, _parent._columns[iinfo], sources, new VectorType(itemType.AsPrimitive, totalSize),
+ return new BoundColumn(InputSchema, _parent._columns[iinfo], sources, new VectorType((PrimitiveType)itemType, totalSize),
isNormalized, hasSlotNames, hasCategoricals, totalSize, catCount);
}
diff --git a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs
index 65e5c52dc4..8e0f146b7c 100644
--- a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs
@@ -518,7 +518,7 @@ private void ComputeType(Schema input, int iinfo, SlotDropper slotDropper,
Host.Assert(typeSrc.IsKnownSizeVector);
var dstLength = slotDropper.DstLength;
var hasSlotNames = input.HasSlotNames(_cols[iinfo], _srcTypes[iinfo].VectorSize);
- type = new VectorType(typeSrc.ItemType.AsPrimitive, Math.Max(dstLength, 1));
+ type = new VectorType((PrimitiveType)typeSrc.ItemType, Math.Max(dstLength, 1));
suppressed = dstLength == 0;
}
}
@@ -822,7 +822,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore()
{
var dstLength = _slotDropper[iinfo].DstLength;
var hasSlotNames = InputSchema.HasSlotNames(_cols[iinfo], _srcTypes[iinfo].VectorSize);
- var type = new VectorType(_srcTypes[iinfo].ItemType.AsPrimitive, Math.Max(dstLength, 1));
+ var type = new VectorType((PrimitiveType)_srcTypes[iinfo].ItemType, Math.Max(dstLength, 1));
if (hasSlotNames && dstLength > 0)
{
diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs
index 9ea54f4897..8c40a0138e 100644
--- a/src/Microsoft.ML.Data/Transforms/Hashing.cs
+++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs
@@ -911,7 +911,7 @@ private void AddMetaKeyValues(int i, MetadataBuilder builder)
{
_parent._keyValues[i].CopyTo(ref dst);
};
- builder.AddKeyValues(_parent._kvTypes[i].VectorSize, _parent._kvTypes[i].ItemType.AsPrimitive, getter);
+ builder.AddKeyValues(_parent._kvTypes[i].VectorSize, (PrimitiveType)_parent._kvTypes[i].ItemType, getter);
}
protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) => _parent.GetGetterCore(input, iinfo, out disposer);
diff --git a/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs b/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs
index 8d041dc6ce..8b683a73bf 100644
--- a/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs
+++ b/src/Microsoft.ML.Data/Transforms/InvertHashUtils.cs
@@ -41,12 +41,12 @@ public static ValueMapper GetSimpleMapper(Schema schema, in
var conv = Conversion.Conversions.Instance;
// First: if not key, then get the standard string converison.
- if (!type.IsKey)
+ if (!(type is KeyType keyType))
return conv.GetStringConversion(type);
bool identity;
// Second choice: if key, utilize the KeyValues metadata for that key, if it has one and is text.
- if (schema.HasKeyValues(col, type.KeyCount))
+ if (schema.HasKeyValues(col, keyType.KeyCount))
{
// REVIEW: Non-textual KeyValues are certainly possible. Should we handle them?
// Get the key names.
@@ -70,7 +70,7 @@ public static ValueMapper GetSimpleMapper(Schema schema, in
}
// Third choice: just use the key value itself, subject to offsetting by the min.
- return conv.GetKeyStringConversion(type.AsKey);
+ return conv.GetKeyStringConversion(keyType);
}
public static ValueMapper, StringBuilder> GetPairMapper(ValueMapper submap)
diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs
index f2b4277059..b5de375360 100644
--- a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs
+++ b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs
@@ -232,10 +232,10 @@ private void ComputeKvMaps(Schema schema, out ColumnType[] types, out KeyToValue
var typeVals = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, ColMapNewToOld[iinfo]);
Host.Check(typeVals != null, "Metadata KeyValues does not exist");
Host.Check(typeVals.VectorSize == typeSrc.ItemType.KeyCount, "KeyValues metadata size does not match column type key count");
- if (!typeSrc.IsVector)
+ if (!(typeSrc is VectorType vectorType))
types[iinfo] = typeVals.ItemType;
else
- types[iinfo] = new VectorType(typeVals.ItemType.AsPrimitive, typeSrc.AsVector);
+ types[iinfo] = new VectorType((PrimitiveType)typeVals.ItemType, vectorType);
// MarshalInvoke with two generic params.
Func func = GetKeyMetadata;
@@ -258,7 +258,7 @@ private KeyToValueMap GetKeyMetadata(int iinfo, ColumnType typeKey
Host.Check(keyMetadata.Length == typeKey.ItemType.KeyCount);
VBufferUtils.Densify(ref keyMetadata);
- return new KeyToValueMap(this, typeKey.ItemType.AsKey, typeVal.ItemType.AsPrimitive, keyMetadata, iinfo);
+ return new KeyToValueMap(this, (KeyType)typeKey.ItemType, (PrimitiveType)typeVal.ItemType, keyMetadata, iinfo);
}
///
/// A map is an object capable of creating the association from an input type, to an output
diff --git a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs
index 1c781b3e32..493d313cc2 100644
--- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs
+++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs
@@ -370,7 +370,7 @@ internal static bool GetNewType(IExceptionContext ectx, ColumnType srcType, Data
if (!srcType.ItemType.IsKey && !srcType.ItemType.IsText)
return false;
}
- else if (!srcType.ItemType.IsKey)
+ else if (!(srcType.ItemType is KeyType key))
itemType = PrimitiveType.FromKind(kind);
else if (!KeyType.IsValidDataKind(kind))
{
@@ -379,7 +379,6 @@ internal static bool GetNewType(IExceptionContext ectx, ColumnType srcType, Data
}
else
{
- var key = srcType.ItemType.AsKey;
ectx.Assert(KeyType.IsValidDataKind(key.RawKind));
int count = key.Count;
// Technically, it's an error for the counts not to match, but we'll let the Conversions
@@ -436,8 +435,8 @@ private static bool CanConvertToType(IExceptionContext ectx, ColumnType srcType,
return false;
typeDst = itemType;
- if (srcType.IsVector)
- typeDst = new VectorType(itemType, srcType.AsVector);
+ if (srcType is VectorType vectorType)
+ typeDst = new VectorType(itemType, vectorType);
return true;
}
@@ -473,9 +472,9 @@ protected override Delegate MakeGetter(Row input, int iinfo, Func act
Contracts.AssertValue(input);
Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
disposer = null;
- if (!_types[iinfo].IsVector)
+ if (!(_types[iinfo] is VectorType vectorType))
return RowCursorUtils.GetGetterAs(_types[iinfo], input, _srcCols[iinfo]);
- return RowCursorUtils.GetVecGetterAs(_types[iinfo].AsVector.ItemType, input, _srcCols[iinfo]);
+ return RowCursorUtils.GetVecGetterAs(vectorType.ItemType, input, _srcCols[iinfo]);
}
public void SaveAsOnnx(OnnxContext ctx)
@@ -507,7 +506,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName,
node.AddAttribute("to", (byte)_parent._columns[iinfo].OutputKind);
if (_parent._columns[iinfo].OutputKeyRange != null)
{
- var key = _types[iinfo].ItemType.AsKey;
+ var key = (KeyType)_types[iinfo].ItemType;
node.AddAttribute("min", key.Min);
node.AddAttribute("max", key.Count);
node.AddAttribute("contiguous", key.Contiguous);
diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs
index cda1d03401..b179f358dc 100644
--- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs
+++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs
@@ -736,8 +736,8 @@ public Mapper(ValueToKeyMappingTransformer parent, Schema inputSchema)
var type = _infos[i].TypeSrc;
KeyType keyType = _parent._unboundMaps[i].OutputType;
ColumnType colType;
- if (type.IsVector)
- colType = new VectorType(keyType, type.AsVector);
+ if (type is VectorType vectorType)
+ colType = new VectorType(keyType, vectorType);
else
colType = keyType;
_types[i] = colType;
diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs
index 87c112b87d..8c7961c969 100644
--- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs
+++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs
@@ -49,7 +49,7 @@ public static Builder Create(ColumnType type, SortOrder sortOrder)
Contracts.Assert(sortOrder == SortOrder.Occurrence || sortOrder == SortOrder.Value);
bool sorted = sortOrder == SortOrder.Value;
- PrimitiveType itemType = type.ItemType.AsPrimitive;
+ PrimitiveType itemType = type.ItemType as PrimitiveType;
Contracts.AssertValue(itemType);
if (itemType.IsText)
return new TextImpl(sorted);
@@ -568,7 +568,7 @@ private static TermMap LoadCodecCore(ModelLoadContext ctx, IExceptionContext
}
}
- return new HashArrayImpl(codec.Type.AsPrimitive, values);
+ return new HashArrayImpl((PrimitiveType)codec.Type, values);
}
internal abstract void WriteTextTerms(TextWriter writer);
@@ -1101,7 +1101,7 @@ private bool AddMetadataCore(ColumnType srcMetaType, MetadataBuilder buil
_host.AssertValue(srcMetaType);
_host.Assert(srcMetaType.RawType == typeof(TMeta));
_host.AssertValue(builder);
- var srcType = TypedMap.ItemType.AsKey;
+ var srcType = TypedMap.ItemType as KeyType;
_host.AssertValue(srcType);
var dstType = new KeyType(DataKind.U4, srcType.Min, srcType.Count);
var convInst = Runtime.Data.Conversion.Conversions.Instance;
@@ -1156,7 +1156,7 @@ private bool AddMetadataCore(ColumnType srcMetaType, MetadataBuilder buil
getter(ref dst);
_host.Assert(dst.Length == TypedMap.OutputType.KeyCount);
};
- builder.AddKeyValues(TypedMap.OutputType.KeyCount, srcMetaType.ItemType.AsPrimitive, mgetter);
+ builder.AddKeyValues(TypedMap.OutputType.KeyCount, (PrimitiveType)srcMetaType.ItemType, mgetter);
}
return true;
}
@@ -1169,7 +1169,7 @@ public override void WriteTextTerms(TextWriter writer)
_schema.TryGetColumnIndex(_infos[_iinfo].Source, out int srcCol);
ColumnType srcMetaType = _schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, srcCol);
if (srcMetaType == null || srcMetaType.VectorSize != TypedMap.ItemType.KeyCount ||
- TypedMap.ItemType.KeyCount == 0 || !Utils.MarshalInvoke(WriteTextTermsCore, srcMetaType.ItemType.RawType, srcMetaType.AsVector.ItemType, writer))
+ TypedMap.ItemType.KeyCount == 0 || !Utils.MarshalInvoke(WriteTextTermsCore, srcMetaType.ItemType.RawType, ((VectorType)srcMetaType).ItemType, writer))
{
// No valid input key-value metadata. Back off to the base implementation.
base.WriteTextTerms(writer);
@@ -1180,7 +1180,7 @@ private bool WriteTextTermsCore(PrimitiveType srcMetaType, TextWriter wri
{
_host.AssertValue(srcMetaType);
_host.Assert(srcMetaType.RawType == typeof(TMeta));
- var srcType = TypedMap.ItemType.AsKey;
+ var srcType = TypedMap.ItemType as KeyType;
_host.AssertValue(srcType);
var dstType = new KeyType(DataKind.U4, srcType.Min, srcType.Count);
var convInst = Runtime.Data.Conversion.Conversions.Instance;
diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs
index dafa19eb61..d393e6c316 100644
--- a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs
+++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs
@@ -603,7 +603,7 @@ protected static int CheckLabelColumn(IHostEnvironment env, IPredictorModel[] mo
if (mdType == null || !mdType.IsKnownSizeVector)
throw env.Except("Label column of type key must have a vector of key values metadata");
- return Utils.MarshalInvoke(CheckKeyLabelColumnCore, mdType.ItemType.RawType, env, models, labelType.AsKey, schema, labelInfo.Index, mdType);
+ return Utils.MarshalInvoke(CheckKeyLabelColumnCore, mdType.ItemType.RawType, env, models, (KeyType)labelType, schema, labelInfo.Index, mdType);
}
// When the label column is not a key, we check that the number of classes is the same for all the predictors, by checking the
@@ -655,8 +655,8 @@ private static int CheckKeyLabelColumnCore(IHostEnvironment env, IPredictorMo
if (labelInfo == null)
throw env.Except("Training schema for model {0} does not have a label column", i);
- var curLabelType = rmd.Schema.Schema.GetColumnType(rmd.Schema.Label.Index);
- if (!labelType.Equals(curLabelType.AsKey))
+ var curLabelType = rmd.Schema.Schema.GetColumnType(rmd.Schema.Label.Index) as KeyType;
+ if (!labelType.Equals(curLabelType))
throw env.Except("Label column of model {0} has different type than model 0", i);
var mdType = rmd.Schema.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, labelInfo.Index);
diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
index 7eef450580..ec3827286c 100644
--- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
+++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
@@ -788,7 +788,7 @@ private static IDataView AppendLabelTransform(IHostEnvironment env, IChannel ch,
"labelPermutationSeed != 0 only applies on a multi-class learning problem when the label type is a key.");
return input;
}
- return Utils.MarshalInvoke(AppendFloatMapper, labelType.RawType, env, ch, input, labelName, labelType.AsKey,
+ return Utils.MarshalInvoke(AppendFloatMapper, labelType.RawType, env, ch, input, labelName, (KeyType)labelType,
labelPermutationSeed);
}
}
diff --git a/src/Microsoft.ML.HalLearners/VectorWhitening.cs b/src/Microsoft.ML.HalLearners/VectorWhitening.cs
index 2a0358f74a..707e8c780f 100644
--- a/src/Microsoft.ML.HalLearners/VectorWhitening.cs
+++ b/src/Microsoft.ML.HalLearners/VectorWhitening.cs
@@ -323,7 +323,8 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol
// Check if the input column's type is supported. Note that only float vector with a known shape is allowed.
internal static string TestColumn(ColumnType type)
{
- if ((type.IsVector && !type.IsKnownSizeVector && (type.AsVector.Dimensions.Length > 1)) || type.ItemType != NumberType.R4)
+ if ((type is VectorType vectorType && !vectorType.IsKnownSizeVector && vectorType.Dimensions.Length > 1)
+ || type.ItemType != NumberType.R4)
return "Expected float or float vector of known size";
if ((long)type.ValueCount * type.ValueCount > Utils.ArrayMaxSize)
diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
index 46300386ab..2216beadb7 100644
--- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
+++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs
@@ -139,14 +139,14 @@ protected override void ConvertNaNLabels(IChannel ch, RoleMappedData data, float
if (maxLabel >= _maxNumClass)
throw ch.ExceptParam(nameof(data), $"max labelColumn cannot exceed {_maxNumClass}");
- if (data.Schema.Label.Type.IsKey)
+ if (data.Schema.Label.Type is KeyType keyType)
{
- ch.Check(data.Schema.Label.Type.AsKey.Contiguous, "labelColumn value should be contiguous");
+ ch.Check(keyType.Contiguous, "labelColumn value should be contiguous");
if (hasNaNLabel)
- _numClass = data.Schema.Label.Type.AsKey.Count + 1;
+ _numClass = keyType.Count + 1;
else
- _numClass = data.Schema.Label.Type.AsKey.Count;
- _tlcNumClass = data.Schema.Label.Type.AsKey.Count;
+ _numClass = keyType.Count;
+ _tlcNumClass = keyType.Count;
}
else
{
diff --git a/src/Microsoft.ML.Onnx/OnnxUtils.cs b/src/Microsoft.ML.Onnx/OnnxUtils.cs
index 1d6c4b6fc1..877fc7b9ed 100644
--- a/src/Microsoft.ML.Onnx/OnnxUtils.cs
+++ b/src/Microsoft.ML.Onnx/OnnxUtils.cs
@@ -296,10 +296,10 @@ public static ModelArgs GetModelArgs(ColumnType type, string colName,
TensorProto.Types.DataType dataType = TensorProto.Types.DataType.Undefined;
DataKind rawKind;
- if (type.IsVector)
- rawKind = type.AsVector.ItemType.RawKind;
- else if (type.IsKey)
- rawKind = type.AsKey.RawKind;
+ if (type is VectorType vectorType)
+ rawKind = vectorType.ItemType.RawKind;
+ else if (type is KeyType keyType)
+ rawKind = keyType.RawKind;
else
rawKind = type.RawKind;
@@ -367,7 +367,7 @@ public static ModelArgs GetModelArgs(ColumnType type, string colName,
dimsLocal.Add(1);
else if (type.ValueCount > 1)
{
- var vec = type.AsVector;
+ var vec = (VectorType)type;
for (int i = 0; i < vec.Dimensions.Length; i++)
dimsLocal.Add(vec.Dimensions[i]);
}
diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs
index e42c6b002e..d858d5bef4 100644
--- a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs
+++ b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs
@@ -347,7 +347,7 @@ private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data,
using (var buffer = PrepareBuffer())
{
buffer.Train(ch, rowCount, colCount, cursor, labGetter, matrixRowIndexGetter, matrixColumnIndexGetter);
- predictor = new MatrixFactorizationPredictor(Host, buffer, matrixColumnIndexColInfo.Type.AsKey, matrixRowIndexColInfo.Type.AsKey);
+ predictor = new MatrixFactorizationPredictor(Host, buffer, (KeyType)matrixColumnIndexColInfo.Type, (KeyType)matrixRowIndexColInfo.Type);
}
}
else
@@ -365,7 +365,7 @@ private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data,
buffer.TrainWithValidation(ch, rowCount, colCount,
cursor, labGetter, matrixRowIndexGetter, matrixColumnIndexGetter,
validCursor, validLabelGetter, validMatrixRowIndexGetter, validMatrixColumnIndexGetter);
- predictor = new MatrixFactorizationPredictor(Host, buffer, matrixColumnIndexColInfo.Type.AsKey, matrixRowIndexColInfo.Type.AsKey);
+ predictor = new MatrixFactorizationPredictor(Host, buffer, (KeyType)matrixColumnIndexColInfo.Type, (KeyType)matrixRowIndexColInfo.Type);
}
}
}
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
index be73a2de30..821cb21cb0 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
@@ -299,14 +299,16 @@ private FieldAwareFactorizationMachinePredictor TrainCore(IChannel ch, IProgress
for (int f = 0; f < fieldCount; f++)
{
var col = featureColumns[f];
- Host.Assert(col.Type.AsVector.VectorSize > 0);
if (col == null)
throw ch.ExceptParam(nameof(data), "Empty feature column not allowed");
Host.Assert(!data.Schema.Schema.IsHidden(col.Index));
- if (!col.Type.IsKnownSizeVector || col.Type.ItemType != NumberType.Float)
+ if (!(col.Type is VectorType vectorType) ||
+ !vectorType.IsKnownSizeVector ||
+ vectorType.ItemType != NumberType.Float)
throw ch.ExceptParam(nameof(data), "Training feature column '{0}' must be a known-size vector of R4, but has type: {1}.", col.Name, col.Type);
+ Host.Assert(vectorType.VectorSize > 0);
fieldColumnIndexes[f] = col.Index;
- totalFeatureCount += col.Type.AsVector.VectorSize;
+ totalFeatureCount += vectorType.VectorSize;
}
ch.Check(checked(totalFeatureCount * fieldCount * _latentDimAligned) <= Utils.ArrayMaxSize, "Latent dimension or the number of fields too large");
if (predictor != null)
diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
index b7a1ba901f..f555545099 100644
--- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
@@ -837,7 +837,7 @@ public Mapper(TensorFlowTransform parent, Schema inputSchema) :
else
{
if (shape.Select((dim, j) => dim != -1 && dim != colTypeDims[j]).Any(b => b))
- throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is {type.AsVector.ToString()}.");
+ throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is {vecType.ToString()}.");
// Fill in the unknown dimensions.
var l = new long[originalShape.NumDimensions];
diff --git a/src/Microsoft.ML.Transforms/GroupTransform.cs b/src/Microsoft.ML.Transforms/GroupTransform.cs
index 1f77e1dd16..1b9e341007 100644
--- a/src/Microsoft.ML.Transforms/GroupTransform.cs
+++ b/src/Microsoft.ML.Transforms/GroupTransform.cs
@@ -287,8 +287,9 @@ private static ColumnType[] BuildColumnTypes(Schema input, int[] ids)
for (int i = 0; i < ids.Length; i++)
{
var srcType = input.GetColumnType(ids[i]);
- Contracts.Assert(srcType.IsPrimitive);
- types[i] = new VectorType(srcType.AsPrimitive, size: 0);
+ var primitiveType = srcType as PrimitiveType;
+ Contracts.Assert(primitiveType != null);
+ types[i] = new VectorType(primitiveType, size: 0);
}
return types;
}
diff --git a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs
index 77e89f8882..cc678facfd 100644
--- a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs
+++ b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs
@@ -164,7 +164,7 @@ public Mapper(MissingValueDroppingTransformer parent, Schema inputSchema) :
inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _srcCols[i]);
var srcCol = inputSchema[_srcCols[i]];
_srcTypes[i] = srcCol.Type;
- _types[i] = new VectorType(srcCol.Type.ItemType.AsPrimitive);
+ _types[i] = new VectorType((PrimitiveType)srcCol.Type.ItemType);
_isNAs[i] = GetIsNADelegate(srcCol.Type);
}
}
diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs
index a58603c16e..09ab2853da 100644
--- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs
+++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransform.cs
@@ -140,11 +140,11 @@ private VectorType[] GetTypesAndMetadata()
// This ensures that our feature count doesn't overflow.
Host.Check(type.ValueCount < int.MaxValue / 2);
- if (!type.IsVector)
+ if (!(type is VectorType vectorType))
types[iinfo] = new VectorType(NumberType.Float, 2);
else
{
- types[iinfo] = new VectorType(NumberType.Float, type.AsVector, 2);
+ types[iinfo] = new VectorType(NumberType.Float, vectorType, 2);
// Produce slot names metadata iff the source has (valid) slot names.
ColumnType typeNames;
diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs
index 73efd1d719..8e2580accb 100644
--- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs
+++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs
@@ -182,10 +182,10 @@ private ColInfo[] CreateInfos(Schema inputSchema)
_parent.CheckInputColumn(inputSchema, i, colSrc);
var inType = inputSchema.GetColumnType(colSrc);
ColumnType outType;
- if (!inType.IsVector)
+ if (!(inType is VectorType vectorType))
outType = BoolType.Instance;
else
- outType = new VectorType(BoolType.Instance, inType.AsVector);
+ outType = new VectorType(BoolType.Instance, vectorType);
infos[i] = new ColInfo(_parent.ColumnPairs[i].input, _parent.ColumnPairs[i].output, inType, outType);
}
return infos;
@@ -469,7 +469,9 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta))
metadata.Add(slotMeta);
metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false));
- ColumnType type = !col.ItemType.IsVector ? (ColumnType)BoolType.Instance : new VectorType(BoolType.Instance, col.ItemType.AsVector);
+ ColumnType type = !(col.ItemType is VectorType vectorType) ?
+ (ColumnType)BoolType.Instance :
+ new VectorType(BoolType.Instance, vectorType);
result[colPair.output] = new SchemaShape.Column(colPair.output, col.Kind, type, false, new SchemaShape(metadata.ToArray()));
}
return new SchemaShape(result.Values);
diff --git a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs
index 296aed0eab..a6579cb9b2 100644
--- a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs
+++ b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs
@@ -324,8 +324,8 @@ private void GetReplacementValues(IDataView input, ColumnInfo[] columns, out obj
input.Schema.TryGetColumnIndex(columns[iinfo].Input, out int colSrc);
sources[iinfo] = colSrc;
var type = input.Schema.GetColumnType(colSrc);
- if (type.IsVector)
- type = new VectorType(type.ItemType.AsPrimitive, type.AsVector);
+ if (type is VectorType vectorType)
+ type = new VectorType((PrimitiveType)type.ItemType, vectorType);
Delegate isNa = GetIsNADelegate(type);
types[iinfo] = type;
var kind = (ReplacementKind)columns[iinfo].Replacement;
@@ -593,8 +593,8 @@ public Mapper(MissingValueReplacingTransformer parent, Schema inputSchema)
for (int i = 0; i < _parent.ColumnPairs.Length; i++)
{
var type = _infos[i].TypeSrc;
- if (type.IsVector)
- type = new VectorType(type.ItemType.AsPrimitive, type.AsVector);
+ if (type is VectorType vectorType)
+ type = new VectorType((PrimitiveType)type.ItemType, vectorType);
var repType = _parent._repIsDefault[i] != null ? _parent._replaceTypes[i] : _parent._replaceTypes[i].ItemType;
if (!type.ItemType.Equals(repType.ItemType))
throw Host.ExceptParam(nameof(InputSchema), "Column '{0}' item type '{1}' does not match expected ColumnType of '{2}'",
@@ -897,10 +897,10 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
{
DataKind rawKind;
var type = _infos[iinfo].TypeSrc;
- if (type.IsVector)
- rawKind = type.AsVector.ItemType.RawKind;
- else if (type.IsKey)
- rawKind = type.AsKey.RawKind;
+ if (type is VectorType vectorType)
+ rawKind = vectorType.ItemType.RawKind;
+ else if (type is KeyType keyType)
+ rawKind = keyType.RawKind;
else
rawKind = type.RawKind;
@@ -965,7 +965,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
metadata.Add(slotMeta);
if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.IsNormalized, out var normalized))
metadata.Add(normalized);
- var type = !col.ItemType.IsVector ? col.ItemType : new VectorType(col.ItemType.ItemType.AsPrimitive, col.ItemType.AsVector);
+ var type = !(col.ItemType is VectorType vectorType) ?
+ col.ItemType :
+ new VectorType((PrimitiveType)col.ItemType.ItemType, vectorType);
result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, type, false, new SchemaShape(metadata.ToArray()));
}
return new SchemaShape(result.Values);
diff --git a/src/Microsoft.ML.Transforms/UngroupTransform.cs b/src/Microsoft.ML.Transforms/UngroupTransform.cs
index e52f6d5354..5354dfda0b 100644
--- a/src/Microsoft.ML.Transforms/UngroupTransform.cs
+++ b/src/Microsoft.ML.Transforms/UngroupTransform.cs
@@ -291,7 +291,7 @@ private static void CheckAndBind(IExceptionContext ectx, Schema inputSchema,
if (!colType.IsVector || !colType.ItemType.IsPrimitive)
throw ectx.ExceptUserArg(nameof(Arguments.Column),
"Pivot column '{0}' has type '{1}', but must be a vector of primitive types", name, colType);
- infos[i] = new PivotColumnInfo(name, col, colType.VectorSize, colType.ItemType.AsPrimitive);
+ infos[i] = new PivotColumnInfo(name, col, colType.VectorSize, (PrimitiveType)colType.ItemType);
}
}