From fa4b3fdeaa9b92f6405acf8cbf009af7e9f9c1b5 Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Mon, 25 Feb 2019 11:00:47 -0600 Subject: [PATCH 1/3] Make DataViewRowId not act like a number. - Remove it from the NumberDataViewType. - Remove any method/operator that makes it feel like a number. Working towards #2297 --- src/Microsoft.Data.DataView/DataViewRowId.cs | 63 +----- src/Microsoft.Data.DataView/DataViewType.cs | 46 ++++- .../Data/ColumnTypeExtensions.cs | 8 +- src/Microsoft.ML.Data/Data/DataViewUtils.cs | 4 +- .../Transforms/RowShufflingTransformer.cs | 2 +- src/Microsoft.ML.Parquet/ParquetLoader.cs | 2 +- .../StaticSchemaShape.cs | 4 +- .../MissingValueReplacingUtils.cs | 189 +++++++++++------- .../ProduceIdTransform.cs | 2 +- .../UnitTests/ColumnTypes.cs | 2 +- .../Transformers/HashTests.cs | 2 +- 11 files changed, 167 insertions(+), 157 deletions(-) diff --git a/src/Microsoft.Data.DataView/DataViewRowId.cs b/src/Microsoft.Data.DataView/DataViewRowId.cs index 586ad62915..87544279d5 100644 --- a/src/Microsoft.Data.DataView/DataViewRowId.cs +++ b/src/Microsoft.Data.DataView/DataViewRowId.cs @@ -9,7 +9,7 @@ namespace Microsoft.Data.DataView { /// - /// A structure serving as a sixteen-byte unsigned integer. It is used as the row id of . + /// A structure serving as the identifier of a row of . /// For datasets with millions of records, those IDs need to be unique, therefore the need for such a large structure to hold the values. /// Those Ids are derived from other Ids of the previous components of the pipelines, and dividing the structure in two: high order and low order of bits, /// and reduces the changes of those collisions even further. @@ -53,70 +53,13 @@ public bool Equals(DataViewRowId other) public override bool Equals(object obj) { - if (obj != null && obj is DataViewRowId) + if (obj is DataViewRowId other) { - var item = (DataViewRowId)obj; - return Equals(item); + return Equals(other); } return false; } - public static DataViewRowId operator +(DataViewRowId first, ulong second) - { - ulong resHi = first.High; - ulong resLo = first.Low + second; - if (resLo < second) - resHi++; - return new DataViewRowId(resLo, resHi); - } - - public static DataViewRowId operator -(DataViewRowId first, ulong second) - { - ulong resHi = first.High; - ulong resLo = first.Low - second; - if (resLo > first.Low) - resHi--; - return new DataViewRowId(resLo, resHi); - } - - public static bool operator ==(DataViewRowId first, ulong second) - { - return first.High == 0 && first.Low == second; - } - - public static bool operator !=(DataViewRowId first, ulong second) - { - return !(first == second); - } - - public static bool operator <(DataViewRowId first, ulong second) - { - return first.High == 0 && first.Low < second; - } - - public static bool operator >(DataViewRowId first, ulong second) - { - return first.High > 0 || first.Low > second; - } - - public static bool operator <=(DataViewRowId first, ulong second) - { - return first.High == 0 && first.Low <= second; - } - - public static bool operator >=(DataViewRowId first, ulong second) - { - return first.High > 0 || first.Low >= second; - } - - public static explicit operator double(DataViewRowId x) - { - // REVIEW: The 64-bit JIT has a bug where rounding might be not quite - // correct when converting a ulong to double with the high bit set. Should we - // care and compensate? See the DoubleParser code for a work-around. - return x.High * ((double)(1UL << 32) * (1UL << 32)) + x.Low; - } - public override int GetHashCode() { return (int)( diff --git a/src/Microsoft.Data.DataView/DataViewType.cs b/src/Microsoft.Data.DataView/DataViewType.cs index 98b458781e..8440ad4cef 100644 --- a/src/Microsoft.Data.DataView/DataViewType.cs +++ b/src/Microsoft.Data.DataView/DataViewType.cs @@ -199,17 +199,6 @@ public static NumberDataViewType UInt64 } } - private static volatile NumberDataViewType _instDataViewRowId; - public static NumberDataViewType DataViewRowId - { - get - { - return _instDataViewRowId ?? - Interlocked.CompareExchange(ref _instDataViewRowId, new NumberDataViewType(typeof(DataViewRowId), "UG"), null) ?? - _instDataViewRowId; - } - } - private static volatile NumberDataViewType _instSingle; public static NumberDataViewType Single { @@ -243,6 +232,41 @@ public override bool Equals(DataViewType other) public override string ToString() => _name; } + /// + /// The DataViewRowId type. + /// + public sealed class RowIdDataViewType : PrimitiveDataViewType + { + private static volatile RowIdDataViewType _instance; + public static RowIdDataViewType Instance + { + get + { + return _instance ?? + Interlocked.CompareExchange(ref _instance, new RowIdDataViewType(), null) ?? + _instance; + } + } + + private RowIdDataViewType() + : base(typeof(DataViewRowId)) + { + } + + public override bool Equals(DataViewType other) + { + if (other == this) + return true; + Debug.Assert(!(other is RowIdDataViewType)); + return false; + } + + public override string ToString() + { + return "DataViewRowId"; + } + } + /// /// The standard boolean type. /// diff --git a/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs b/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs index 1213561385..6863d381a3 100644 --- a/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs +++ b/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs @@ -103,6 +103,8 @@ public static PrimitiveDataViewType PrimitiveTypeFromType(Type type) return DateTimeDataViewType.Instance; if (type == typeof(DateTimeOffset)) return DateTimeOffsetDataViewType.Instance; + if (type == typeof(DataViewRowId)) + return RowIdDataViewType.Instance; return NumberTypeFromType(type); } @@ -118,6 +120,8 @@ public static PrimitiveDataViewType PrimitiveTypeFromKind(InternalDataKind kind) return DateTimeDataViewType.Instance; if (kind == InternalDataKind.DZ) return DateTimeOffsetDataViewType.Instance; + if (kind == InternalDataKind.UG) + return RowIdDataViewType.Instance; return NumberTypeFromKind(kind); } @@ -131,7 +135,7 @@ public static NumberDataViewType NumberTypeFromType(Type type) throw new InvalidOperationException($"Bad type in {nameof(ColumnTypeExtensions)}.{nameof(NumberTypeFromType)}: {type}"); } - public static NumberDataViewType NumberTypeFromKind(InternalDataKind kind) + private static NumberDataViewType NumberTypeFromKind(InternalDataKind kind) { switch (kind) { @@ -155,8 +159,6 @@ public static NumberDataViewType NumberTypeFromKind(InternalDataKind kind) return NumberDataViewType.Single; case InternalDataKind.R8: return NumberDataViewType.Double; - case InternalDataKind.UG: - return NumberDataViewType.DataViewRowId; } Contracts.Assert(false); diff --git a/src/Microsoft.ML.Data/Data/DataViewUtils.cs b/src/Microsoft.ML.Data/Data/DataViewUtils.cs index 001fbe360f..021314c881 100644 --- a/src/Microsoft.ML.Data/Data/DataViewUtils.cs +++ b/src/Microsoft.ML.Data/Data/DataViewUtils.cs @@ -357,7 +357,7 @@ private static DataViewRowCursor ConsolidateCore(IChannelProvider provider, Data outPipes[i] = OutPipe.Create(type, pool); } int idIdx = activeToCol.Length + (int)ExtraIndex.Id; - outPipes[idIdx] = OutPipe.Create(NumberDataViewType.DataViewRowId, GetPool(NumberDataViewType.DataViewRowId, ourPools, idIdx)); + outPipes[idIdx] = OutPipe.Create(RowIdDataViewType.Instance, GetPool(RowIdDataViewType.Instance, ourPools, idIdx)); // Create the structures to synchronize between the workers and the consumer. const int toConsumeBound = 4; @@ -553,7 +553,7 @@ private DataViewRowCursor[] SplitCore(IChannelProvider ch, DataViewRowCursor inp int idIdx = activeToCol.Length + (int)ExtraIndex.Id; inPipes[idIdx] = CreateIdInPipe(input); for (int i = 0; i < cthd; ++i) - outPipes[i][idIdx] = inPipes[idIdx].CreateOutPipe(NumberDataViewType.DataViewRowId); + outPipes[i][idIdx] = inPipes[idIdx].CreateOutPipe(RowIdDataViewType.Instance); var toConsume = new BlockingCollection(toConsumeBound); var batchColumnPool = new MadeObjectPool(() => new BatchColumn[inPipes.Length]); diff --git a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs index 70b69e574b..8dd7b7ba19 100644 --- a/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs @@ -529,7 +529,7 @@ public Cursor(IChannelProvider provider, int poolRows, DataViewRowCursor input, input.Schema[c].Type, RowCursorUtils.GetGetterAsDelegate(input, c)); _getters[ia] = CreateGetterDelegate(c); } - var idPipe = _pipes[numActive + (int)ExtraIndex.Id] = ShufflePipe.Create(_pipeIndices.Length, NumberDataViewType.DataViewRowId, input.GetIdGetter()); + var idPipe = _pipes[numActive + (int)ExtraIndex.Id] = ShufflePipe.Create(_pipeIndices.Length, RowIdDataViewType.Instance, input.GetIdGetter()); _idGetter = CreateGetterDelegate(idPipe); // Initially, after the preamble to MoveNextCore, we want: // liveCount=0, deadCount=0, circularIndex=0. So we set these diff --git a/src/Microsoft.ML.Parquet/ParquetLoader.cs b/src/Microsoft.ML.Parquet/ParquetLoader.cs index fdd10d0700..85336ad6f2 100644 --- a/src/Microsoft.ML.Parquet/ParquetLoader.cs +++ b/src/Microsoft.ML.Parquet/ParquetLoader.cs @@ -350,7 +350,7 @@ private DataViewType ConvertFieldType(DataType parquetType) case DataType.Int64: return NumberDataViewType.Int64; case DataType.Int96: - return NumberDataViewType.DataViewRowId; + return RowIdDataViewType.Instance; case DataType.ByteArray: return new VectorType(NumberDataViewType.Byte); case DataType.String: diff --git a/src/Microsoft.ML.StaticPipe/StaticSchemaShape.cs b/src/Microsoft.ML.StaticPipe/StaticSchemaShape.cs index e830b29934..d251aba49c 100644 --- a/src/Microsoft.ML.StaticPipe/StaticSchemaShape.cs +++ b/src/Microsoft.ML.StaticPipe/StaticSchemaShape.cs @@ -166,7 +166,7 @@ private static Type GetTypeOrNull(SchemaShape.Column col) if (physType != null && ( pt == NumberDataViewType.SByte || pt == NumberDataViewType.Int16 || pt == NumberDataViewType.Int32 || pt == NumberDataViewType.Int32 || pt == NumberDataViewType.Byte || pt == NumberDataViewType.UInt16 || pt == NumberDataViewType.UInt32 || pt == NumberDataViewType.UInt32 || - pt == NumberDataViewType.Single || pt == NumberDataViewType.Double || pt == NumberDataViewType.DataViewRowId || pt == BooleanDataViewType.Instance || + pt == NumberDataViewType.Single || pt == NumberDataViewType.Double || pt == RowIdDataViewType.Instance || pt == BooleanDataViewType.Instance || pt == DateTimeDataViewType.Instance || pt == DateTimeOffsetDataViewType.Instance || pt == TimeSpanDataViewType.Instance || pt == TextDataViewType.Instance)) { @@ -311,7 +311,7 @@ private static Type GetTypeOrNull(DataViewSchema.Column col) if (physType != null && ( pt == NumberDataViewType.SByte || pt == NumberDataViewType.Int16 || pt == NumberDataViewType.Int32 || pt == NumberDataViewType.Int64 || pt == NumberDataViewType.Byte || pt == NumberDataViewType.UInt16 || pt == NumberDataViewType.UInt32 || pt == NumberDataViewType.UInt64 || - pt == NumberDataViewType.Single || pt == NumberDataViewType.Double || pt == NumberDataViewType.DataViewRowId || pt == BooleanDataViewType.Instance || + pt == NumberDataViewType.Single || pt == NumberDataViewType.Double || pt == RowIdDataViewType.Instance || pt == BooleanDataViewType.Instance || pt == DateTimeDataViewType.Instance || pt == DateTimeOffsetDataViewType.Instance || pt == TimeSpanDataViewType.Instance || pt == TextDataViewType.Instance)) { diff --git a/src/Microsoft.ML.Transforms/MissingValueReplacingUtils.cs b/src/Microsoft.ML.Transforms/MissingValueReplacingUtils.cs index bc6024b225..76f4552605 100644 --- a/src/Microsoft.ML.Transforms/MissingValueReplacingUtils.cs +++ b/src/Microsoft.ML.Transforms/MissingValueReplacingUtils.cs @@ -80,6 +80,47 @@ private static StatAggregator CreateStatAggregator(IChannel ch, DataViewType typ "assigned in NAReplaceTransform.", kind, type); } + private static DataViewRowId Add(DataViewRowId left, ulong right) + { + ulong resHi = left.High; + ulong resLo = left.Low + right; + if (resLo < right) + resHi++; + return new DataViewRowId(resLo, resHi); + } + + private static DataViewRowId Subtract(DataViewRowId left, ulong right) + { + ulong resHi = left.High; + ulong resLo = left.Low - right; + if (resLo > left.Low) + resHi--; + return new DataViewRowId(resLo, resHi); + } + + private static bool Equals(DataViewRowId left, ulong right) + { + return left.High == 0 && left.Low == right; + } + + private static bool GreaterThanOrEqual(DataViewRowId left, ulong right) + { + return left.High > 0 || left.Low >= right; + } + + private static bool GreaterThan(DataViewRowId left, ulong right) + { + return left.High > 0 || left.Low > right; + } + + private static double ToDouble(DataViewRowId value) + { + // REVIEW: The 64-bit JIT has a bug where rounding might be not quite + // correct when converting a ulong to double with the high bit set. Should we + // care and compensate? See the DoubleParser code for a work-around. + return value.High * ((double)(1UL << 32) * (1UL << 32)) + value.Low; + } + /// /// The base class for stat aggregators for imputing mean, min, and max for the NAReplaceTransform. /// @@ -161,7 +202,7 @@ protected sealed override void ProcessRow(in VBuffer src) for (int slot = 0; slot < srcCount; slot++) ProcessValue(in srcValues[slot]); - _valueCount = _valueCount + (ulong)src.Length; + _valueCount = Add(_valueCount, (ulong)src.Length); } protected abstract void ProcessValue(in TItem val); @@ -312,11 +353,11 @@ private struct MeanStatDouble // The number of non-zero (finite) values processed. private long _cnz; // The current mean estimate for the _cnz values we've processed. - private Double _cur; + private double _cur; - public void Update(Double val) + public void Update(double val) { - Contracts.Assert(Double.MinValue <= _cur && _cur <= Double.MaxValue); + Contracts.Assert(double.MinValue <= _cur && _cur <= double.MaxValue); if (val == 0) return; @@ -335,12 +376,12 @@ public void Update(Double val) else _cur += (val - _cur) / _cnz; - Contracts.Assert(Double.MinValue <= _cur && _cur <= Double.MaxValue); + Contracts.Assert(double.MinValue <= _cur && _cur <= double.MaxValue); } - public Double GetCurrentValue(IChannel ch, long count) + public double GetCurrentValue(IChannel ch, long count) { - Contracts.Assert(Double.MinValue <= _cur && _cur <= Double.MaxValue); + Contracts.Assert(double.MinValue <= _cur && _cur <= double.MaxValue); Contracts.Assert(_cnz >= 0 && _cna >= 0); Contracts.Assert(count >= _cna); Contracts.Assert(count - _cna >= _cnz); @@ -353,28 +394,28 @@ public Double GetCurrentValue(IChannel ch, long count) } // Fold in the zeros. - Double stat = _cur * ((Double)_cnz / (count - _cna)); - Contracts.Assert(Double.MinValue <= stat && stat <= Double.MaxValue); + double stat = _cur * ((double)_cnz / (count - _cna)); + Contracts.Assert(double.MinValue <= stat && stat <= double.MaxValue); return stat; } - public Double GetCurrentValue(IChannel ch, DataViewRowId count) + public double GetCurrentValue(IChannel ch, DataViewRowId count) { - Contracts.Assert(Double.MinValue <= _cur && _cur <= Double.MaxValue); + Contracts.Assert(double.MinValue <= _cur && _cur <= double.MaxValue); Contracts.Assert(_cnz >= 0 && _cna >= 0); Contracts.Assert(count.High != 0 || count.Low >= (ulong)_cna); // If all values in the column are NAs, emit a warning and return 0. // Is this what we want to do or should an error be thrown? - if (count == (ulong)_cna) + if (Equals(count, (ulong)_cna)) { ch.Warning("All values in this column are NAs, using default value for imputation"); return 0; } // Fold in the zeros. - Double stat = _cur * ((Double)_cnz / (Double)(count - (ulong)_cna)); - Contracts.Assert(Double.MinValue <= stat && stat <= Double.MaxValue); + double stat = _cur * ((double)_cnz / ToDouble(Subtract(count, (ulong)_cna))); + Contracts.Assert(double.MinValue <= stat && stat <= double.MaxValue); return stat; } } @@ -462,20 +503,20 @@ public long GetCurrentValue(IChannel ch, long count, long valMax) public long GetCurrentValue(IChannel ch, DataViewRowId count, long valMax) { AssertValid(valMax); - Contracts.Assert(count >= (ulong)_cna); + Contracts.Assert(GreaterThanOrEqual(count, (ulong)_cna)); // If the sum is zero, return zero. if ((_sumHi | _sumLo) == 0) { // If all values in a given column are NAs issue a warning. - if (count == (ulong)_cna) + if (Equals(count, (ulong)_cna)) ch.Warning("All values in this column are NAs, using default value for imputation"); return 0; } - Contracts.Assert(count > (ulong)_cna); - count -= (ulong)_cna; - Contracts.Assert(count > 0); + Contracts.Assert(GreaterThan(count, (ulong)_cna)); + count = Subtract(count, (ulong)_cna); + Contracts.Assert(GreaterThan(count, 0)); ulong sumHi = _sumHi; ulong sumLo = _sumLo; @@ -495,7 +536,7 @@ public long GetCurrentValue(IChannel ch, DataViewRowId count, long valMax) // a ulong, so the absolute value of the sum can't possibly be so large that sumHi // reaches or exceeds count. This assert implies that the Div part of the DivRound // call won't throw. - Contracts.Assert(count > sumHi); + Contracts.Assert(GreaterThan(count, sumHi)); ulong res = IntUtils.DivRound(sumLo, sumHi, count.Low, count.High); Contracts.Assert(0 <= res && res <= (ulong)valMax); @@ -508,54 +549,54 @@ private static class R4 // Utilizes MeanStatDouble for the mean aggregators, a struct that holds _stat as a double, despite the fact that its // value should always be within the range of a valid Single after processing each value as it is representative of the // mean of a set of Single values. Conversion to Single happens in GetStat. - public sealed class MeanAggregatorOne : StatAggregator + public sealed class MeanAggregatorOne : StatAggregator { public MeanAggregatorOne(IChannel ch, DataViewRowCursor cursor, int col) : base(ch, cursor, col) { } - protected override void ProcessRow(in Single val) + protected override void ProcessRow(in float val) { Stat.Update(val); } public override object GetStat() { - Double val = Stat.GetCurrentValue(Ch, RowCount); - Ch.Assert(Single.MinValue <= val && val <= Single.MaxValue); - return (Single)val; + double val = Stat.GetCurrentValue(Ch, RowCount); + Ch.Assert(float.MinValue <= val && val <= float.MaxValue); + return (float)val; } } - public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots + public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots { public MeanAggregatorAcrossSlots(IChannel ch, DataViewRowCursor cursor, int col) : base(ch, cursor, col) { } - protected override void ProcessValue(in Single val) + protected override void ProcessValue(in float val) { Stat.Update(val); } public override object GetStat() { - Double val = Stat.GetCurrentValue(Ch, ValueCount); - Ch.Assert(Single.MinValue <= val && val <= Single.MaxValue); - return (Single)val; + double val = Stat.GetCurrentValue(Ch, ValueCount); + Ch.Assert(float.MinValue <= val && val <= float.MaxValue); + return (float)val; } } - public sealed class MeanAggregatorBySlot : StatAggregatorBySlot + public sealed class MeanAggregatorBySlot : StatAggregatorBySlot { public MeanAggregatorBySlot(IChannel ch, VectorType type, DataViewRowCursor cursor, int col) : base(ch, type, cursor, col) { } - protected override void ProcessValue(in Single val, int slot) + protected override void ProcessValue(in float val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); Stat[slot].Update(val); @@ -563,53 +604,53 @@ protected override void ProcessValue(in Single val, int slot) public override object GetStat() { - Single[] stat = new Single[Stat.Length]; + float[] stat = new float[Stat.Length]; for (int slot = 0; slot < stat.Length; slot++) { - Double val = Stat[slot].GetCurrentValue(Ch, RowCount); - Ch.Assert(Single.MinValue <= val && val <= Single.MaxValue); - stat[slot] = (Single)val; + double val = Stat[slot].GetCurrentValue(Ch, RowCount); + Ch.Assert(float.MinValue <= val && val <= float.MaxValue); + stat[slot] = (float)val; } return stat; } } - public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne + public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne { public MinMaxAggregatorOne(IChannel ch, DataViewRowCursor cursor, int col, bool returnMax) : base(ch, cursor, col, returnMax) { - Stat = ReturnMax ? Single.NegativeInfinity : Single.PositiveInfinity; + Stat = ReturnMax ? float.NegativeInfinity : float.PositiveInfinity; } - protected override void ProcessValueMin(in Single val) + protected override void ProcessValueMin(in float val) { if (val < Stat) Stat = val; } - protected override void ProcessValueMax(in Single val) + protected override void ProcessValueMax(in float val) { if (val > Stat) Stat = val; } } - public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots + public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots { public MinMaxAggregatorAcrossSlots(IChannel ch, DataViewRowCursor cursor, int col, bool returnMax) : base(ch, cursor, col, returnMax) { - Stat = ReturnMax ? Single.NegativeInfinity : Single.PositiveInfinity; + Stat = ReturnMax ? float.NegativeInfinity : float.PositiveInfinity; } - protected override void ProcessValueMin(in Single val) + protected override void ProcessValueMin(in float val) { if (val < Stat) Stat = val; } - protected override void ProcessValueMax(in Single val) + protected override void ProcessValueMax(in float val) { if (val > Stat) Stat = val; @@ -618,33 +659,33 @@ protected override void ProcessValueMax(in Single val) public override object GetStat() { // If sparsity occurred, fold in a zero. - if (ValueCount > (ulong)ValuesProcessed) + if (GreaterThan(ValueCount, (ulong)ValuesProcessed)) { - Single def = 0; + float def = 0; ProcValueDelegate(in def); } - return (Single)Stat; + return (float)Stat; } } - public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot + public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot { public MinMaxAggregatorBySlot(IChannel ch, VectorType type, DataViewRowCursor cursor, int col, bool returnMax) : base(ch, type, cursor, col, returnMax) { - Single bound = ReturnMax ? Single.NegativeInfinity : Single.PositiveInfinity; + float bound = ReturnMax ? float.NegativeInfinity : float.PositiveInfinity; for (int i = 0; i < Stat.Length; i++) Stat[i] = bound; } - protected override void ProcessValueMin(in Single val, int slot) + protected override void ProcessValueMin(in float val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); if (val < Stat[slot]) Stat[slot] = val; } - protected override void ProcessValueMax(in Single val, int slot) + protected override void ProcessValueMax(in float val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); if (val > Stat[slot]) @@ -658,7 +699,7 @@ public override object GetStat() { if (GetValuesProcessed(slot) < RowCount) { - Single def = 0; + float def = 0; ProcValueDelegate(in def, slot); } } @@ -669,14 +710,14 @@ public override object GetStat() private static class R8 { - public sealed class MeanAggregatorOne : StatAggregator + public sealed class MeanAggregatorOne : StatAggregator { public MeanAggregatorOne(IChannel ch, DataViewRowCursor cursor, int col) : base(ch, cursor, col) { } - protected override void ProcessRow(in Double val) + protected override void ProcessRow(in double val) { Stat.Update(val); } @@ -687,14 +728,14 @@ public override object GetStat() } } - public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots + public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots { public MeanAggregatorAcrossSlots(IChannel ch, DataViewRowCursor cursor, int col) : base(ch, cursor, col) { } - protected override void ProcessValue(in Double val) + protected override void ProcessValue(in double val) { Stat.Update(val); } @@ -705,14 +746,14 @@ public override object GetStat() } } - public sealed class MeanAggregatorBySlot : StatAggregatorBySlot + public sealed class MeanAggregatorBySlot : StatAggregatorBySlot { public MeanAggregatorBySlot(IChannel ch, VectorType type, DataViewRowCursor cursor, int col) : base(ch, type, cursor, col) { } - protected override void ProcessValue(in Double val, int slot) + protected override void ProcessValue(in double val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); Stat[slot].Update(val); @@ -720,49 +761,49 @@ protected override void ProcessValue(in Double val, int slot) public override object GetStat() { - Double[] stat = new Double[Stat.Length]; + double[] stat = new double[Stat.Length]; for (int slot = 0; slot < stat.Length; slot++) stat[slot] = Stat[slot].GetCurrentValue(Ch, RowCount); return stat; } } - public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne + public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne { public MinMaxAggregatorOne(IChannel ch, DataViewRowCursor cursor, int col, bool returnMax) : base(ch, cursor, col, returnMax) { - Stat = ReturnMax ? Double.NegativeInfinity : Double.PositiveInfinity; + Stat = ReturnMax ? double.NegativeInfinity : double.PositiveInfinity; } - protected override void ProcessValueMin(in Double val) + protected override void ProcessValueMin(in double val) { if (val < Stat) Stat = val; } - protected override void ProcessValueMax(in Double val) + protected override void ProcessValueMax(in double val) { if (val > Stat) Stat = val; } } - public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots + public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots { public MinMaxAggregatorAcrossSlots(IChannel ch, DataViewRowCursor cursor, int col, bool returnMax) : base(ch, cursor, col, returnMax) { - Stat = ReturnMax ? Double.NegativeInfinity : Double.PositiveInfinity; + Stat = ReturnMax ? double.NegativeInfinity : double.PositiveInfinity; } - protected override void ProcessValueMin(in Double val) + protected override void ProcessValueMin(in double val) { if (val < Stat) Stat = val; } - protected override void ProcessValueMax(in Double val) + protected override void ProcessValueMax(in double val) { if (val > Stat) Stat = val; @@ -771,26 +812,26 @@ protected override void ProcessValueMax(in Double val) public override object GetStat() { // If sparsity occurred, fold in a zero. - if (ValueCount > (ulong)ValuesProcessed) + if (GreaterThan(ValueCount, (ulong)ValuesProcessed)) { - Double def = 0; + double def = 0; ProcValueDelegate(in def); } return Stat; } } - public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot + public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot { public MinMaxAggregatorBySlot(IChannel ch, VectorType type, DataViewRowCursor cursor, int col, bool returnMax) : base(ch, type, cursor, col, returnMax) { - Double bound = ReturnMax ? Double.MinValue : Double.MaxValue; + double bound = ReturnMax ? double.MinValue : double.MaxValue; for (int i = 0; i < Stat.Length; i++) Stat[i] = bound; } - protected override void ProcessValueMin(in Double val, int slot) + protected override void ProcessValueMin(in double val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); if (FloatUtils.IsFinite(val)) @@ -800,7 +841,7 @@ protected override void ProcessValueMin(in Double val, int slot) } } - protected override void ProcessValueMax(in Double val, int slot) + protected override void ProcessValueMax(in double val, int slot) { Ch.Assert(0 <= slot && slot < Stat.Length); if (FloatUtils.IsFinite(val)) @@ -817,7 +858,7 @@ public override object GetStat() { if (GetValuesProcessed(slot) < RowCount) { - Double def = 0; + double def = 0; ProcValueDelegate(in def, slot); } } diff --git a/src/Microsoft.ML.Transforms/ProduceIdTransform.cs b/src/Microsoft.ML.Transforms/ProduceIdTransform.cs index 84d568b5f3..071213b212 100644 --- a/src/Microsoft.ML.Transforms/ProduceIdTransform.cs +++ b/src/Microsoft.ML.Transforms/ProduceIdTransform.cs @@ -46,7 +46,7 @@ public Bindings(DataViewSchema input, bool user, string name) protected override DataViewType GetColumnTypeCore(int iinfo) { Contracts.Assert(iinfo == 0); - return NumberDataViewType.DataViewRowId; + return RowIdDataViewType.Instance; } public static Bindings Create(ModelLoadContext ctx, DataViewSchema input) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/ColumnTypes.cs b/test/Microsoft.ML.Core.Tests/UnitTests/ColumnTypes.cs index 6f173a8c11..428657ab21 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/ColumnTypes.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/ColumnTypes.cs @@ -18,7 +18,7 @@ public void TestEqualAndGetHashCode() // add PrimitiveTypes, KeyType & corresponding VectorTypes VectorType tmp1, tmp2; var types = new PrimitiveDataViewType[] { NumberDataViewType.SByte, NumberDataViewType.Int16, NumberDataViewType.Int32, NumberDataViewType.Int64, - NumberDataViewType.Byte, NumberDataViewType.UInt16, NumberDataViewType.UInt32, NumberDataViewType.UInt64, NumberDataViewType.DataViewRowId, + NumberDataViewType.Byte, NumberDataViewType.UInt16, NumberDataViewType.UInt32, NumberDataViewType.UInt64, RowIdDataViewType.Instance, TextDataViewType.Instance, BooleanDataViewType.Instance, DateTimeDataViewType.Instance, DateTimeOffsetDataViewType.Instance, TimeSpanDataViewType.Instance }; foreach (var type in types) diff --git a/test/Microsoft.ML.Tests/Transformers/HashTests.cs b/test/Microsoft.ML.Tests/Transformers/HashTests.cs index eac61568a1..2d805f762d 100644 --- a/test/Microsoft.ML.Tests/Transformers/HashTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/HashTests.cs @@ -247,7 +247,7 @@ private void HashTestPositiveIntegerCore(ulong value, uint expected, uint expect HashTestCore(value, NumberDataViewType.UInt64, expected, expectedOrdered, expectedOrdered3); HashTestCore((ulong)value, new KeyType(typeof(ulong), int.MaxValue - 1), eKey, eoKey, e3Key); - HashTestCore(new DataViewRowId(value, 0), NumberDataViewType.DataViewRowId, expected, expectedOrdered, expectedOrdered3); + HashTestCore(new DataViewRowId(value, 0), RowIdDataViewType.Instance, expected, expectedOrdered, expectedOrdered3); // Next let's check signed numbers. From f968de52388f16304c03bd33d7b5d91338f3ce2f Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Mon, 25 Feb 2019 11:58:52 -0600 Subject: [PATCH 2/3] Fix up broken tests. --- src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs | 3 ++- src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs b/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs index 6863d381a3..67e6434d82 100644 --- a/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs +++ b/src/Microsoft.ML.Core/Data/ColumnTypeExtensions.cs @@ -19,7 +19,8 @@ internal static class ColumnTypeExtensions /// public static bool IsStandardScalar(this DataViewType columnType) => (columnType is NumberDataViewType) || (columnType is TextDataViewType) || (columnType is BooleanDataViewType) || - (columnType is TimeSpanDataViewType) || (columnType is DateTimeDataViewType) || (columnType is DateTimeOffsetDataViewType); + (columnType is RowIdDataViewType) || (columnType is TimeSpanDataViewType) || + (columnType is DateTimeDataViewType) || (columnType is DateTimeOffsetDataViewType); /// /// Zero return means it's not a key type. diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs index a9b61880e5..150b62caab 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs @@ -159,7 +159,9 @@ private sealed class UnsafeTypeCodec : SimpleCodec where T : struct // Throws an exception if T is neither a TimeSpan nor a NumberType. private static DataViewType UnsafeColumnType(Type type) { - return type == typeof(TimeSpan) ? (DataViewType)TimeSpanDataViewType.Instance : ColumnTypeExtensions.NumberTypeFromType(type); + return type == typeof(TimeSpan) ? TimeSpanDataViewType.Instance : + type == typeof(DataViewRowId) ? (DataViewType)RowIdDataViewType.Instance : + ColumnTypeExtensions.NumberTypeFromType(type); } public UnsafeTypeCodec(CodecFactory factory) From 3c3efc0b63c079c762e797f6390cb9797f1c76bc Mon Sep 17 00:00:00 2001 From: Eric Erhardt Date: Mon, 25 Feb 2019 12:27:11 -0600 Subject: [PATCH 3/3] Fix one last failing test. --- src/Microsoft.ML.Data/Transforms/Hashing.cs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs index b769cf9649..103dada2da 100644 --- a/src/Microsoft.ML.Data/Transforms/Hashing.cs +++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs @@ -1213,10 +1213,11 @@ internal void Save(ModelSaveContext ctx) internal static bool IsColumnTypeValid(DataViewType type) { var itemType = type.GetItemType(); - return itemType is TextDataViewType || itemType is KeyType || itemType is NumberDataViewType || itemType is BooleanDataViewType; + return itemType is TextDataViewType || itemType is KeyType || itemType is NumberDataViewType || + itemType is BooleanDataViewType || itemType is RowIdDataViewType; } - internal const string ExpectedColumnType = "Expected Text, Key, numeric or Boolean item type"; + internal const string ExpectedColumnType = "Expected Text, Key, numeric, Boolean or DataViewRowId item type"; /// /// Initializes a new instance of .