Skip to content

Remove As methods on ColumnType. #1864

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 0 additions & 18 deletions src/Microsoft.ML.Core/Data/ColumnType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,6 @@ private protected ColumnType(Type rawType, DataKind rawKind)
[BestFriend]
internal bool IsPrimitive { get; }

/// <summary>
/// Equivalent to <c>as <see cref="PrimitiveType"/></c>.
/// </summary>
[BestFriend]
internal PrimitiveType AsPrimitive => IsPrimitive ? (PrimitiveType)this : null;

/// <summary>
/// Whether this type is a standard numeric type. External code should use <c>is <see cref="NumberType"/></c>.
/// </summary>
Expand Down Expand Up @@ -140,12 +134,6 @@ internal bool IsBool
[BestFriend]
internal bool IsKey { get; }

/// <summary>
/// Equivalent to <c>as <see cref="KeyType"/></c>.
/// </summary>
[BestFriend]
internal KeyType AsKey => IsKey ? (KeyType)this : null;

/// <summary>
/// Zero return means either it's not a key type or the cardinality is unknown. External code should first
/// test whether this is of type <see cref="KeyType"/>, then if so get the <see cref="KeyType.Count"/> property
Expand All @@ -166,12 +154,6 @@ internal bool IsBool
[BestFriend]
internal bool IsVector { get; }

/// <summary>
/// Equivalent to <c>as <see cref="VectorType"/></c>.
/// </summary>
[BestFriend]
internal VectorType AsVector => IsVector ? (VectorType)this : null;

/// <summary>
/// For non-vector types, this returns the column type itself (i.e., return <c>this</c>).
/// </summary>
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Core/Data/MetadataUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,12 @@ public static VectorType GetCategoricalType(int rangeCount)
return new VectorType(NumberType.I4, rangeCount, 2);
}

private static volatile ColumnType _scoreColumnSetIdType;
private static volatile KeyType _scoreColumnSetIdType;

/// <summary>
/// The type of the ScoreColumnSetId metadata.
/// </summary>
public static ColumnType ScoreColumnSetIdType
public static KeyType ScoreColumnSetIdType
{
get
{
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Commands/ScoreCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ private void RunCore(IChannel ch)
private bool ShouldAddColumn(Schema schema, int i, uint scoreSet, bool outputNamesAndLabels)
{
uint scoreSetId = 0;
if (schema.TryGetMetadata(MetadataUtils.ScoreColumnSetIdType.AsPrimitive, MetadataUtils.Kinds.ScoreColumnSetId, i, ref scoreSetId)
if (schema.TryGetMetadata(MetadataUtils.ScoreColumnSetIdType, MetadataUtils.Kinds.ScoreColumnSetId, i, ref scoreSetId)
&& scoreSetId == scoreSet)
{
return true;
Expand Down
19 changes: 8 additions & 11 deletions src/Microsoft.ML.Data/Data/Conversion.cs
Original file line number Diff line number Diff line change
Expand Up @@ -412,15 +412,12 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst,

conv = null;
identity = false;
if (typeSrc.IsKey)
if (typeSrc is KeyType keySrc)
{
var keySrc = typeSrc.AsKey;

// Key types are only convertable to compatible key types or unsigned integer
// types that are large enough.
if (typeDst.IsKey)
if (typeDst is KeyType keyDst)
{
var keyDst = typeDst.AsKey;
// We allow the Min value to shift. We currently don't allow the counts to vary.
// REVIEW: Should we allow the counts to vary? Allowing the dst to be bigger is trivial.
// Smaller dst means mapping values to NA.
Expand Down Expand Up @@ -451,11 +448,11 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst,
// REVIEW: Should we look for illegal values and force them to zero? If so, then
// we'll need to set identity to false.
}
else if (typeDst.IsKey)
else if (typeDst is KeyType keyDst)
{
if (!typeSrc.IsText)
return false;
conv = GetKeyParse(typeDst.AsKey);
conv = GetKeyParse(keyDst);
return true;
}
else if (!typeDst.IsStandardScalar)
Expand Down Expand Up @@ -490,10 +487,10 @@ public bool TryGetStringConversion<TSrc>(ColumnType type, out ValueMapper<TSrc,
Contracts.CheckValue(type, nameof(type));
Contracts.Check(type.RawType == typeof(TSrc), "Wrong TSrc type argument");

if (type.IsKey)
if (type is KeyType keyType)
{
// Key string conversion always works.
conv = GetKeyStringConversion<TSrc>(type.AsKey);
conv = GetKeyStringConversion<TSrc>(keyType);
return true;
}
return TryGetStringConversion(out conv);
Expand Down Expand Up @@ -572,8 +569,8 @@ public TryParseMapper<TDst> GetTryParseConversion<TDst>(ColumnType typeDst)
"Parse conversion only supported for standard types");
Contracts.Check(typeDst.RawType == typeof(TDst), "Wrong TDst type parameter");

if (typeDst.IsKey)
return GetKeyTryParse<TDst>(typeDst.AsKey);
if (typeDst is KeyType keyType)
return GetKeyTryParse<TDst>(keyType);

Contracts.Assert(_tryParseDelegates.ContainsKey(typeDst.RawKind));
return (TryParseMapper<TDst>)_tryParseDelegates[typeDst.RawKind];
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ private static ColumnType MakeColumnType(SchemaShape.Column inputCol)
{
ColumnType curType = inputCol.ItemType;
if (inputCol.IsKey)
curType = new KeyType(curType.AsPrimitive.RawKind, 0, AllKeySizes);
curType = new KeyType(((PrimitiveType)curType).RawKind, 0, AllKeySizes);
if (inputCol.Kind == SchemaShape.Column.VectorKind.VariableVector)
curType = new VectorType(curType.AsPrimitive, 0);
curType = new VectorType((PrimitiveType)curType, 0);
else if (inputCol.Kind == SchemaShape.Column.VectorKind.Vector)
curType = new VectorType(curType.AsPrimitive, AllVectorSizes);
curType = new VectorType((PrimitiveType)curType, AllVectorSizes);
return curType;
}

Expand Down
3 changes: 1 addition & 2 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -840,9 +840,8 @@ public void Save(ModelSaveContext ctx)
Contracts.Assert((DataKind)(byte)type.RawKind == type.RawKind);
ctx.Writer.Write((byte)type.RawKind);
ctx.Writer.WriteBoolByte(type.IsKey);
if (type.IsKey)
if (type is KeyType key)
{
var key = type.AsKey;
ctx.Writer.WriteBoolByte(key.Contiguous);
ctx.Writer.Write(key.Min);
ctx.Writer.Write(key.Count);
Expand Down
12 changes: 7 additions & 5 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -663,12 +663,14 @@ public Parser(TextLoader parent)
{
var info = _infos[i];

if (info.ColType.ItemType.IsKey)
if (info.ColType is KeyType keyType)
{
if (!info.ColType.IsVector)
_creator[i] = cache.GetCreatorOne(info.ColType.AsKey);
else
_creator[i] = cache.GetCreatorVec(info.ColType.ItemType.AsKey);
_creator[i] = cache.GetCreatorOne(keyType);
continue;
}
else if (info.ColType is VectorType vectorType && vectorType.ItemType is KeyType vectorKeyType)
{
_creator[i] = cache.GetCreatorVec(vectorKeyType);
continue;
}

Expand Down
3 changes: 1 addition & 2 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -489,9 +489,8 @@ private TextLoader.Column GetColumn(string name, ColumnType type, int? start)
{
DataKind? kind;
KeyRange keyRange = null;
if (type.ItemType.IsKey)
if (type.ItemType is KeyType key)
{
var key = type.ItemType.AsKey;
if (!key.Contiguous)
keyRange = new KeyRange(key.Min, contiguous: false);
else if (key.Count == 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ public VectorType GetSlotType(int col)
var view = _parent._entries[col].GetViewOrNull();
if (view == null)
return null;
return view.Schema.GetColumnType(0).AsVector;
return view.Schema.GetColumnType(0) as VectorType;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public bool IsColumnSavable(ColumnType type)
// an artificial vector type out of this. Obviously if you can't make a vector
// out of the items, then you could not save each slot's values.
var itemType = type.ItemType;
var primitiveType = itemType.AsPrimitive;
var primitiveType = itemType as PrimitiveType;
if (primitiveType == null)
return false;
var vectorType = new VectorType(primitiveType, size: 2);
Expand Down
13 changes: 7 additions & 6 deletions src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -208,12 +208,12 @@ private Delegate CreateGetter(ColumnType colType, InternalSchemaDefinition.Colum
else
Host.Assert(colType.RawType == outputType);

if (!colType.IsKey)
if (!(colType is KeyType keyType))
del = CreateDirectGetterDelegate<int>;
else
{
var keyRawType = colType.RawType;
Host.Assert(colType.AsKey.Contiguous);
Host.Assert(keyType.Contiguous);
Func<Delegate, ColumnType, Delegate> delForKey = CreateKeyGetterDelegate<uint>;
return Utils.MarshalInvoke(delForKey, keyRawType, peek, colType);
}
Expand Down Expand Up @@ -299,9 +299,10 @@ private Delegate CreateDirectGetterDelegate<TDst>(Delegate peekDel)
private Delegate CreateKeyGetterDelegate<TDst>(Delegate peekDel, ColumnType colType)
{
// Make sure the function is dealing with key.
Host.Check(colType.IsKey);
KeyType keyType = colType as KeyType;
Host.Check(keyType != null);
// Following equations work only with contiguous key type.
Host.Check(colType.AsKey.Contiguous);
Host.Check(keyType.Contiguous);
// Following equations work only with unsigned integers.
Host.Check(typeof(TDst) == typeof(ulong) || typeof(TDst) == typeof(uint) ||
typeof(TDst) == typeof(byte) || typeof(TDst) == typeof(bool));
Expand All @@ -312,8 +313,8 @@ private Delegate CreateKeyGetterDelegate<TDst>(Delegate peekDel, ColumnType colT

TDst rawKeyValue = default;
ulong key = 0; // the raw key value as ulong
ulong min = colType.AsKey.Min;
ulong max = min + (ulong)colType.AsKey.Count - 1;
ulong min = keyType.Min;
ulong max = min + (ulong)keyType.Count - 1;
ulong result = 0; // the result as ulong
ValueGetter<TDst> getter = (ref TDst dst) =>
{
Expand Down
7 changes: 4 additions & 3 deletions src/Microsoft.ML.Data/DataView/Transposer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,9 @@ public SchemaImpl(Transposer parent)
{
ColumnInfo srcInfo = _parent._cols[c];
var ctype = srcInfo.Type.ItemType;
_ectx.Assert(ctype.IsPrimitive);
_slotTypes[c] = new VectorType(ctype.AsPrimitive, _parent.RowCount);
var primitiveType = ctype as PrimitiveType;
_ectx.Assert(primitiveType != null);
_slotTypes[c] = new VectorType(primitiveType, _parent.RowCount);
}

AsSchema = Schema.Create(this);
Expand Down Expand Up @@ -1189,7 +1190,7 @@ private sealed class ColumnSplitter<T> : Splitter
public ColumnSplitter(IDataView view, int col, int[] lims)
: base(view, col)
{
var type = _view.Schema.GetColumnType(SrcCol).AsVector;
var type = _view.Schema.GetColumnType(SrcCol) as VectorType;
// Only valid use is for two or more slices.
Contracts.Assert(Utils.Size(lims) >= 2);
Contracts.AssertValue(type);
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public static CommonOutputs.TransformOutput SelectColumns(IHostEnvironment env,
private static bool ShouldAddColumn(Schema schema, int i, string[] extraColumns, uint scoreSet)
{
uint scoreSetId = 0;
if (schema.TryGetMetadata(MetadataUtils.ScoreColumnSetIdType.AsPrimitive, MetadataUtils.Kinds.ScoreColumnSetId, i, ref scoreSetId)
if (schema.TryGetMetadata(MetadataUtils.ScoreColumnSetIdType, MetadataUtils.Kinds.ScoreColumnSetId, i, ref scoreSetId)
&& scoreSetId == scoreSet)
{
return true;
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ public static void ReconcileVectorKeyValues(IHostEnvironment env, IDataView[] vi
(ref VBuffer<ReadOnlyMemory<char>> dst) => schema.GetMetadata(MetadataUtils.Kinds.SlotNames, index, ref dst);
}
views[i] = LambdaColumnMapper.Create(env, "ReconcileKeyValues", views[i], columnName, columnName,
type, new VectorType(keyType, type.AsVector), mapper, keyValueGetter, slotNamesGetter);
type, new VectorType(keyType, type as VectorType), mapper, keyValueGetter, slotNamesGetter);
}
}

Expand Down Expand Up @@ -974,7 +974,7 @@ private static bool VerifyVectorColumnsMatch(int cachedSize, int col, IDataView
private static IDataView AddVarLengthColumn<TSrc>(IHostEnvironment env, IDataView idv, string variableSizeVectorColumnName, ColumnType typeSrc)
{
return LambdaColumnMapper.Create(env, "ChangeToVarLength", idv, variableSizeVectorColumnName,
variableSizeVectorColumnName + "_VarLength", typeSrc, new VectorType(typeSrc.ItemType.AsPrimitive),
variableSizeVectorColumnName + "_VarLength", typeSrc, new VectorType((PrimitiveType)typeSrc.ItemType),
(in VBuffer<TSrc> src, ref VBuffer<TSrc> dst) => src.CopyTo(ref dst));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1001,11 +1001,11 @@ protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMa
if (!perInst.Schema.TryGetColumnIndex(schema.Label.Name, out int labelCol))
throw Host.Except("Could not find column '{0}'", schema.Label.Name);
var labelType = perInst.Schema.GetColumnType(labelCol);
if (labelType.IsKey && (!perInst.Schema.HasKeyValues(labelCol, labelType.KeyCount) || labelType.RawKind != DataKind.U4))
if (labelType is KeyType keyType && (!perInst.Schema.HasKeyValues(labelCol, keyType.KeyCount) || labelType.RawKind != DataKind.U4))
{
perInst = LambdaColumnMapper.Create(Host, "ConvertToDouble", perInst, schema.Label.Name,
schema.Label.Name, perInst.Schema.GetColumnType(labelCol), NumberType.R8,
(in uint src, ref double dst) => dst = src == 0 ? double.NaN : src - 1 + (double)labelType.AsKey.Min);
(in uint src, ref double dst) => dst = src == 0 ? double.NaN : src - 1 + (double)keyType.Min);
}

var perInstSchema = perInst.Schema;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ private void CheckInputColumnTypes(Schema schema, out ColumnType labelType, out
var t = schema.GetColumnType(LabelIndex);
if (!t.IsKnownSizeVector || (t.ItemType != NumberType.R4 && t.ItemType != NumberType.R8))
throw Host.Except("Label column '{0}' has type '{1}' but must be a known-size vector of R4 or R8", LabelCol, t);
labelType = new VectorType(t.ItemType.AsPrimitive, t.VectorSize);
labelType = new VectorType((PrimitiveType)t.ItemType, t.VectorSize);
var slotNamesType = new VectorType(TextType.Instance, t.VectorSize);
var builder = new MetadataBuilder();
builder.AddSlotNames(t.VectorSize, CreateSlotNamesGetter(schema, LabelIndex, labelType.VectorSize, "True"));
Expand All @@ -558,7 +558,7 @@ private void CheckInputColumnTypes(Schema schema, out ColumnType labelType, out
t = schema.GetColumnType(ScoreIndex);
if (t.VectorSize == 0 || t.ItemType != NumberType.Float)
throw Host.Except("Score column '{0}' has type '{1}' but must be a known length vector of type R4", ScoreCol, t);
scoreType = new VectorType(t.ItemType.AsPrimitive, t.VectorSize);
scoreType = new VectorType((PrimitiveType)t.ItemType, t.VectorSize);
builder = new MetadataBuilder();
builder.AddSlotNames(t.VectorSize, CreateSlotNamesGetter(schema, ScoreIndex, scoreType.VectorSize, "Predicted"));

Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ private static ISchemaBoundMapper WrapCore<T>(IHostEnvironment env, ISchemaBound
trainSchema.Label.Index, ref value);
};

return MultiClassClassifierScorer.LabelNameBindableMapper.CreateBound<T>(env, (ISchemaBoundRowMapper)mapper, type.AsVector, getter, MetadataUtils.Kinds.TrainingLabelValues, CanWrap);
return MultiClassClassifierScorer.LabelNameBindableMapper.CreateBound<T>(env, (ISchemaBoundRowMapper)mapper, type as VectorType, getter, MetadataUtils.Kinds.TrainingLabelValues, CanWrap);
}

public BinaryClassifierScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ public RowMapper(IHostEnvironment env, BindableMapper parent, RoleMappedSchema s
else
{
_outputSchema = Schema.Create(new FeatureContributionSchema(_env, DefaultColumnNames.FeatureContributions,
new VectorType(NumberType.R4, schema.Feature.Type.AsVector),
new VectorType(NumberType.R4, schema.Feature.Type as VectorType),
InputSchema, InputRoleMappedSchema.Feature.Index));
}

Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ private LabelNameBindableMapper(IHost host, ModelLoadContext ctx)
ColumnType type;
object value;
_host.CheckDecode(saver.TryLoadTypeAndValue(ctx.Reader.BaseStream, out type, out value));
_host.CheckDecode(type.IsVector);
_type = type as VectorType;
_host.CheckDecode(_type != null);
_host.CheckDecode(value != null);
_type = type.AsVector;
_getter = Utils.MarshalInvoke(DecodeInit<int>, _type.ItemType.RawType, value);
_metadataKind = ctx.Header.ModelVerReadable >= VersionAddedMetadataKind ?
ctx.LoadNonEmptyString() : MetadataUtils.Kinds.SlotNames;
Expand Down Expand Up @@ -493,7 +493,7 @@ public static ISchemaBoundMapper WrapCore<T>(IHostEnvironment env, ISchemaBoundM
trainSchema.Label.Index, ref value);
};

return LabelNameBindableMapper.CreateBound<T>(env, (ISchemaBoundRowMapper)mapper, type.AsVector, getter, MetadataUtils.Kinds.SlotNames, CanWrap);
return LabelNameBindableMapper.CreateBound<T>(env, (ISchemaBoundRowMapper)mapper, type as VectorType, getter, MetadataUtils.Kinds.SlotNames, CanWrap);
}

public MultiClassClassifierScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ private BoundColumn MakeColumn(Schema inputSchema, int iinfo)
hasSlotNames = false;
}

return new BoundColumn(InputSchema, _parent._columns[iinfo], sources, new VectorType(itemType.AsPrimitive, totalSize),
return new BoundColumn(InputSchema, _parent._columns[iinfo], sources, new VectorType((PrimitiveType)itemType, totalSize),
isNormalized, hasSlotNames, hasCategoricals, totalSize, catCount);
}

Expand Down
Loading