diff --git a/src/Microsoft.ML.Api/ApiUtils.cs b/src/Microsoft.ML.Api/ApiUtils.cs index 5b5936b5a8..4b7daf0a74 100644 --- a/src/Microsoft.ML.Api/ApiUtils.cs +++ b/src/Microsoft.ML.Api/ApiUtils.cs @@ -6,7 +6,6 @@ using System.Reflection; using System.Reflection.Emit; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Utilities; namespace Microsoft.ML.Runtime.Api { @@ -19,11 +18,12 @@ internal static class ApiUtils private static OpCode GetAssignmentOpCode(Type t) { // REVIEW: This should be a Dictionary<Type, OpCode> based solution. - // DvTexts, strings, arrays, and VBuffers. + // DvTypes, strings, arrays, all nullable types, VBuffers and UInt128. if (t == typeof(DvInt8) || t == typeof(DvInt4) || t == typeof(DvInt2) || t == typeof(DvInt1) || - t == typeof(DvBool) || t==typeof(bool?) || t == typeof(DvText) || t == typeof(string) || t.IsArray || - (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) || t == typeof(DvDateTime) || - t == typeof(DvDateTimeZone) || t == typeof(DvTimeSpan) || t == typeof(UInt128)) + t == typeof(DvBool) || t == typeof(DvText) || t == typeof(string) || t.IsArray || + (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) || + (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Nullable<>)) || + t == typeof(DvDateTime) || t == typeof(DvDateTimeZone) || t == typeof(DvTimeSpan) || t == typeof(UInt128)) { return OpCodes.Stobj; } diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index f185dc6c0b..6ecff5b204 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -119,7 +119,7 @@ private Delegate CreateGetter(int index) var column = DataView._schema.SchemaDefn.Columns[index]; var outputType = column.IsComputed ? column.ReturnType : column.FieldInfo.FieldType; - + var genericType = outputType; Func<int, Delegate> del; if (outputType.IsArray) @@ -129,11 +129,66 @@ private Delegate CreateGetter(int index) if (outputType.GetElementType() == typeof(string)) { Ch.Assert(colType.ItemType.IsText); - return CreateStringArrayToVBufferGetter(index); + return CreateConvertingArrayGetterDelegate<String, DvText>(index, x => x == null ? DvText.NA : new DvText(x)); + } + else if (outputType.GetElementType() == typeof(int)) + { + Ch.Assert(colType.ItemType == NumberType.I4); + return CreateConvertingArrayGetterDelegate<int, DvInt4>(index, x => x); + } + else if (outputType.GetElementType() == typeof(int?)) + { + Ch.Assert(colType.ItemType == NumberType.I4); + return CreateConvertingArrayGetterDelegate<int?, DvInt4>(index, x => x ?? DvInt4.NA); + } + else if (outputType.GetElementType() == typeof(long)) + { + Ch.Assert(colType.ItemType == NumberType.I8); + return CreateConvertingArrayGetterDelegate<long, DvInt8>(index, x => x); + } + else if (outputType.GetElementType() == typeof(long?)) + { + Ch.Assert(colType.ItemType == NumberType.I8); + return CreateConvertingArrayGetterDelegate<long?, DvInt8>(index, x => x ?? DvInt8.NA); + } + else if (outputType.GetElementType() == typeof(short)) + { + Ch.Assert(colType.ItemType == NumberType.I2); + return CreateConvertingArrayGetterDelegate<short, DvInt2>(index, x => x); + } + else if (outputType.GetElementType() == typeof(short?)) + { + Ch.Assert(colType.ItemType == NumberType.I2); + return CreateConvertingArrayGetterDelegate<short?, DvInt2>(index, x => x ?? DvInt2.NA); + } + else if (outputType.GetElementType() == typeof(sbyte)) + { + Ch.Assert(colType.ItemType == NumberType.I1); + return CreateConvertingArrayGetterDelegate<sbyte, DvInt1>(index, x => x); } + else if (outputType.GetElementType() == typeof(sbyte?)) + { + Ch.Assert(colType.ItemType == NumberType.I1); + return CreateConvertingArrayGetterDelegate<sbyte?, DvInt1>(index, x => x ?? DvInt1.NA); + } + else if (outputType.GetElementType() == typeof(bool)) + { + Ch.Assert(colType.ItemType.IsBool); + return CreateConvertingArrayGetterDelegate<bool, DvBool>(index, x => x); + } + else if (outputType.GetElementType() == typeof(bool?)) + { + Ch.Assert(colType.ItemType.IsBool); + return CreateConvertingArrayGetterDelegate<bool?, DvBool>(index, x => x ?? DvBool.NA); + } + // T[] -> VBuffer<T> - Ch.Assert(outputType.GetElementType() == colType.ItemType.RawType); - del = CreateArrayToVBufferGetter<int>; + if (outputType.GetElementType().IsGenericType && outputType.GetElementType().GetGenericTypeDefinition() == typeof(Nullable<>)) + Ch.Assert(Nullable.GetUnderlyingType(outputType.GetElementType()) == colType.ItemType.RawType); + else + Ch.Assert(outputType.GetElementType() == colType.ItemType.RawType); + del = CreateDirectArrayGetterDelegate<int>; + genericType = outputType.GetElementType(); } else if (colType.IsVector) { @@ -142,7 +197,7 @@ private Delegate CreateGetter(int index) Ch.Assert(outputType.IsGenericType); Ch.Assert(outputType.GetGenericTypeDefinition() == typeof(VBuffer<>)); Ch.Assert(outputType.GetGenericArguments()[0] == colType.ItemType.RawType); - del = CreateVBufferToVBufferDelegate<int>; + del = CreateDirectVBufferGetterDelegate<int>; } else if (colType.IsPrimitive) { @@ -150,24 +205,74 @@ private Delegate CreateGetter(int index) { // String -> DvText Ch.Assert(colType.IsText); - return CreateStringToTextGetter(index); + return CreateConvertingGetterDelegate<String, DvText>(index, x => x == null ? DvText.NA : new DvText(x)); } else if (outputType == typeof(bool)) { // Bool -> DvBool Ch.Assert(colType.IsBool); - return CreateBooleanToDvBoolGetter(index); + return CreateConvertingGetterDelegate<bool, DvBool>(index, x => x); } else if (outputType == typeof(bool?)) { // Bool? -> DvBool Ch.Assert(colType.IsBool); - return CreateNullableBooleanToDvBoolGetter(index); + return CreateConvertingGetterDelegate<bool?, DvBool>(index, x => x ?? DvBool.NA); + } + else if (outputType == typeof(int)) + { + // int -> DvInt4 + Ch.Assert(colType == NumberType.I4); + return CreateConvertingGetterDelegate<int, DvInt4>(index, x => x); + } + else if (outputType == typeof(int?)) + { + // int? -> DvInt4 + Ch.Assert(colType == NumberType.I4); + return CreateConvertingGetterDelegate<int?, DvInt4>(index, x => x ?? DvInt4.NA); + } + else if (outputType == typeof(short)) + { + // short -> DvInt2 + Ch.Assert(colType == NumberType.I2); + return CreateConvertingGetterDelegate<short, DvInt2>(index, x => x); + } + else if (outputType == typeof(short?)) + { + // short? -> DvInt2 + Ch.Assert(colType == NumberType.I2); + return CreateConvertingGetterDelegate<short?, DvInt2>(index, x => x ?? DvInt2.NA); + } + else if (outputType == typeof(long)) + { + // long -> DvInt8 + Ch.Assert(colType == NumberType.I8); + return CreateConvertingGetterDelegate<long, DvInt8>(index, x => x); + } + else if (outputType == typeof(long?)) + { + // long? -> DvInt8 + Ch.Assert(colType == NumberType.I8); + return CreateConvertingGetterDelegate<long?, DvInt8>(index, x => x ?? DvInt8.NA); + } + else if (outputType == typeof(sbyte)) + { + // sbyte -> DvInt1 + Ch.Assert(colType == NumberType.I1); + return CreateConvertingGetterDelegate<sbyte, DvInt1>(index, x => x); + } + else if (outputType == typeof(sbyte?)) + { + // sbyte? -> DvInt1 + Ch.Assert(colType == NumberType.I1); + return CreateConvertingGetterDelegate<sbyte?, DvInt1>(index, x => x ?? DvInt1.NA); } - // T -> T - Ch.Assert(colType.RawType == outputType); - del = CreateDirectGetter<int>; + if (outputType.IsGenericType && outputType.GetGenericTypeDefinition() == typeof(Nullable<>)) + Ch.Assert(colType.RawType == Nullable.GetUnderlyingType(outputType)); + else + Ch.Assert(colType.RawType == outputType); + del = CreateDirectGetterDelegate<int>; } else { @@ -175,66 +280,43 @@ private Delegate CreateGetter(int index) throw Ch.ExceptNotImpl("Type '{0}' is not yet supported.", outputType.FullName); } MethodInfo meth = - del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(colType.ItemType.RawType); + del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(genericType); return (Delegate)meth.Invoke(this, new object[] { index }); } - private Delegate CreateStringArrayToVBufferGetter(int index) + // REVIEW: The converting getter invokes a type conversion delegate on every call, so it's inherently slower + // than the 'direct' getter. We don't have good indication of this to the user, and the selection + // of affected types is pretty arbitrary (signed integers and bools, but not uints and floats). + private Delegate CreateConvertingArrayGetterDelegate<TSrc, TDst>(int index, Func<TSrc, TDst> convert) { - var peek = DataView._peeks[index] as Peek<TRow, string[]>; + var peek = DataView._peeks[index] as Peek<TRow, TSrc[]>; Ch.AssertValue(peek); - - string[] buf = null; - - return (ValueGetter<VBuffer<DvText>>)((ref VBuffer<DvText> dst) => + TSrc[] buf = default; + return (ValueGetter<VBuffer<TDst>>)((ref VBuffer<TDst> dst) => { peek(GetCurrentRowObject(), Position, ref buf); var n = Utils.Size(buf); - dst = new VBuffer<DvText>(n, Utils.Size(dst.Values) < n - ? new DvText[n] + dst = new VBuffer<TDst>(n, Utils.Size(dst.Values) < n + ? new TDst[n] : dst.Values, dst.Indices); for (int i = 0; i < n; i++) - dst.Values[i] = new DvText(buf[i]); - }); - } - - private Delegate CreateStringToTextGetter(int index) - { - var peek = DataView._peeks[index] as Peek<TRow, string>; - Ch.AssertValue(peek); - string buf = null; - return (ValueGetter<DvText>)((ref DvText dst) => - { - peek(GetCurrentRowObject(), Position, ref buf); - dst = new DvText(buf); - }); - } - - private Delegate CreateBooleanToDvBoolGetter(int index) - { - var peek = DataView._peeks[index] as Peek<TRow, bool>; - Ch.AssertValue(peek); - bool buf = false; - return (ValueGetter<DvBool>)((ref DvBool dst) => - { - peek(GetCurrentRowObject(), Position, ref buf); - dst = (DvBool)buf; + dst.Values[i] = convert(buf[i]); }); } - private Delegate CreateNullableBooleanToDvBoolGetter(int index) + private Delegate CreateConvertingGetterDelegate<TSrc, TDst>(int index, Func<TSrc, TDst> convert) { - var peek = DataView._peeks[index] as Peek<TRow, bool?>; + var peek = DataView._peeks[index] as Peek<TRow, TSrc>; Ch.AssertValue(peek); - bool? buf = null; - return (ValueGetter<DvBool>)((ref DvBool dst) => + TSrc buf = default; + return (ValueGetter<TDst>)((ref TDst dst) => { peek(GetCurrentRowObject(), Position, ref buf); - dst = buf.HasValue ? (DvBool)buf.Value : DvBool.NA; + dst = convert(buf); }); } - private Delegate CreateArrayToVBufferGetter<TDst>(int index) + private Delegate CreateDirectArrayGetterDelegate<TDst>(int index) { var peek = DataView._peeks[index] as Peek<TRow, TDst[]>; Ch.AssertValue(peek); @@ -250,26 +332,29 @@ private Delegate CreateArrayToVBufferGetter<TDst>(int index) }); } - private Delegate CreateVBufferToVBufferDelegate<TDst>(int index) + private Delegate CreateDirectVBufferGetterDelegate<TDst>(int index) { var peek = DataView._peeks[index] as Peek<TRow, VBuffer<TDst>>; Ch.AssertValue(peek); VBuffer<TDst> buf = default(VBuffer<TDst>); return (ValueGetter<VBuffer<TDst>>)((ref VBuffer<TDst> dst) => - { - // The peek for a VBuffer is just a simple assignment, so there is - // no copy going on in the peek, so we must do that as a second - // step to the destination. - peek(GetCurrentRowObject(), Position, ref buf); - buf.CopyTo(ref dst); - }); + { + // The peek for a VBuffer is just a simple assignment, so there is + // no copy going on in the peek, so we must do that as a second + // step to the destination. + peek(GetCurrentRowObject(), Position, ref buf); + buf.CopyTo(ref dst); + }); } - private Delegate CreateDirectGetter<TDst>(int index) + private Delegate CreateDirectGetterDelegate<TDst>(int index) { var peek = DataView._peeks[index] as Peek<TRow, TDst>; Ch.AssertValue(peek); - return (ValueGetter<TDst>)((ref TDst dst) => { peek(GetCurrentRowObject(), Position, ref dst); }); + return (ValueGetter<TDst>)((ref TDst dst) => + { + peek(GetCurrentRowObject(), Position, ref dst); + }); } protected abstract TRow GetCurrentRowObject(); diff --git a/src/Microsoft.ML.Api/SchemaDefinition.cs b/src/Microsoft.ML.Api/SchemaDefinition.cs index 5f84712625..559e3a81ee 100644 --- a/src/Microsoft.ML.Api/SchemaDefinition.cs +++ b/src/Microsoft.ML.Api/SchemaDefinition.cs @@ -329,6 +329,9 @@ public static SchemaDefinition Create(Type userType) // of later at cursor creation. if (fieldInfo.FieldType == typeof(IChannel)) continue; + // Const fields do not need to be mapped. + if (fieldInfo.IsLiteral) + continue; if (fieldInfo.GetCustomAttribute<NoColumnAttribute>() != null) continue; diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index 29fee77a02..2ba9eeb23a 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -271,7 +271,7 @@ private Action<TRow> GenerateSetter(IRow input, int index, InternalSchemaDefinit var colType = input.Schema.GetColumnType(index); var fieldInfo = column.FieldInfo; var fieldType = fieldInfo.FieldType; - + var genericType = fieldType; Func<IRow, int, Delegate, Delegate, Action<TRow>> del; if (fieldType.IsArray) { @@ -280,11 +280,66 @@ private Action<TRow> GenerateSetter(IRow input, int index, InternalSchemaDefinit if (fieldType.GetElementType() == typeof(string)) { Ch.Assert(colType.ItemType.IsText); - return CreateVBufferToStringArraySetter(input, index, poke, peek); + return CreateConvertingVBufferSetter<DvText, string>(input, index, poke, peek, x => x.ToString()); + } + else if (fieldType.GetElementType() == typeof(bool)) + { + Ch.Assert(colType.ItemType.IsBool); + return CreateConvertingVBufferSetter<DvBool, bool>(input, index, poke, peek, x => (bool)x); + } + else if (fieldType.GetElementType() == typeof(bool?)) + { + Ch.Assert(colType.ItemType.IsBool); + return CreateConvertingVBufferSetter<DvBool, bool?>(input, index, poke, peek, x => (bool?)x); + } + else if (fieldType.GetElementType() == typeof(int)) + { + Ch.Assert(colType.ItemType == NumberType.I4); + return CreateConvertingVBufferSetter<DvInt4, int>(input, index, poke, peek, x => (int)x); + } + else if (fieldType.GetElementType() == typeof(int?)) + { + Ch.Assert(colType.ItemType == NumberType.I4); + return CreateConvertingVBufferSetter<DvInt4, int?>(input, index, poke, peek, x => (int?)x); + } + else if (fieldType.GetElementType() == typeof(short)) + { + Ch.Assert(colType.ItemType == NumberType.I2); + return CreateConvertingVBufferSetter<DvInt2, short>(input, index, poke, peek, x => (short)x); + } + else if (fieldType.GetElementType() == typeof(short?)) + { + Ch.Assert(colType.ItemType == NumberType.I2); + return CreateConvertingVBufferSetter<DvInt2, short?>(input, index, poke, peek, x => (short?)x); + } + else if (fieldType.GetElementType() == typeof(long)) + { + Ch.Assert(colType.ItemType == NumberType.I8); + return CreateConvertingVBufferSetter<DvInt8, long>(input, index, poke, peek, x => (long)x); + } + else if (fieldType.GetElementType() == typeof(long?)) + { + Ch.Assert(colType.ItemType == NumberType.I8); + return CreateConvertingVBufferSetter<DvInt8, long?>(input, index, poke, peek, x => (long?)x); } + else if (fieldType.GetElementType() == typeof(sbyte)) + { + Ch.Assert(colType.ItemType == NumberType.I1); + return CreateConvertingVBufferSetter<DvInt1, sbyte>(input, index, poke, peek, x => (sbyte)x); + } + else if (fieldType.GetElementType() == typeof(sbyte?)) + { + Ch.Assert(colType.ItemType == NumberType.I1); + return CreateConvertingVBufferSetter<DvInt1, sbyte?>(input, index, poke, peek, x => (sbyte?)x); + } + // VBuffer<T> -> T[] - Ch.Assert(fieldType.GetElementType() == colType.ItemType.RawType); - del = CreateVBufferToArraySetter<int>; + if (fieldType.GetElementType().IsGenericType && fieldType.GetElementType().GetGenericTypeDefinition() == typeof(Nullable<>)) + Ch.Assert(colType.ItemType.RawType == Nullable.GetUnderlyingType(fieldType.GetElementType())); + else + Ch.Assert(colType.ItemType.RawType == fieldType.GetElementType()); + del = CreateDirectVBufferSetter<int>; + genericType = fieldType.GetElementType(); } else if (colType.IsVector) { @@ -302,53 +357,111 @@ private Action<TRow> GenerateSetter(IRow input, int index, InternalSchemaDefinit // DvText -> String Ch.Assert(colType.IsText); Ch.Assert(peek == null); - return CreateTextToStringSetter(input, index, poke); + return CreateConvertingActionSetter<DvText, string>(input, index, poke, x => x.ToString()); } else if (fieldType == typeof(bool)) { Ch.Assert(colType.IsBool); Ch.Assert(peek == null); - return CreateDvBoolToBoolSetter(input, index, poke); + return CreateConvertingActionSetter<DvBool, bool>(input, index, poke, x => (bool)x); } - else + else if (fieldType == typeof(bool?)) { - // T -> T - Ch.Assert(colType.RawType == fieldType); - del = CreateDirectSetter<int>; + Ch.Assert(colType.IsBool); + Ch.Assert(peek == null); + return CreateConvertingActionSetter<DvBool, bool?>(input, index, poke, x => (bool?)x); + } + else if (fieldType == typeof(int)) + { + Ch.Assert(colType == NumberType.I4); + Ch.Assert(peek == null); + return CreateConvertingActionSetter<DvInt4, int>(input, index, poke, x => (int)x); } + else if (fieldType == typeof(int?)) + { + Ch.Assert(colType == NumberType.I4); + Ch.Assert(peek == null); + return CreateConvertingActionSetter<DvInt4, int?>(input, index, poke, x => (int?)x); + } + else if (fieldType == typeof(short)) + { + Ch.Assert(colType == NumberType.I2); + Ch.Assert(peek == null); + return CreateConvertingActionSetter<DvInt2, short>(input, index, poke, x => (short)x); + } + else if (fieldType == typeof(short?)) + { + Ch.Assert(colType == NumberType.I2); + Ch.Assert(peek == null); + return CreateConvertingActionSetter<DvInt2, short?>(input, index, poke, x => (short?)x); + } + else if (fieldType == typeof(long)) + { + Ch.Assert(colType == NumberType.I8); + Ch.Assert(peek == null); + return CreateConvertingActionSetter<DvInt8, long>(input, index, poke, x => (long)x); + } + else if (fieldType == typeof(long?)) + { + Ch.Assert(colType == NumberType.I8); + Ch.Assert(peek == null); + return CreateConvertingActionSetter<DvInt8, long?>(input, index, poke, x => (long?)x); + } + else if (fieldType == typeof(sbyte)) + { + Ch.Assert(colType == NumberType.I1); + Ch.Assert(peek == null); + return CreateConvertingActionSetter<DvInt1, sbyte>(input, index, poke, x => (sbyte)x); + } + else if (fieldType == typeof(sbyte?)) + { + Ch.Assert(colType == NumberType.I1); + Ch.Assert(peek == null); + return CreateConvertingActionSetter<DvInt1, sbyte?>(input, index, poke, x => (sbyte?)x); + } + // T -> T + if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Nullable<>)) + Ch.Assert(colType.RawType == Nullable.GetUnderlyingType(fieldType)); + else + Ch.Assert(colType.RawType == fieldType); + + del = CreateDirectSetter<int>; } else { // REVIEW: Is this even possible? throw Ch.ExceptNotImpl("Type '{0}' is not yet supported.", fieldInfo.FieldType.FullName); } - MethodInfo meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(colType.ItemType.RawType); + MethodInfo meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(genericType); return (Action<TRow>)meth.Invoke(this, new object[] { input, index, poke, peek }); } - private Action<TRow> CreateVBufferToStringArraySetter(IRow input, int col, Delegate poke, Delegate peek) + // REVIEW: The converting getter invokes a type conversion delegate on every call, so it's inherently slower + // than the 'direct' getter. We don't have good indication of this to the user, and the selection + // of affected types is pretty arbitrary (signed integers and bools, but not uints and floats). + private Action<TRow> CreateConvertingVBufferSetter<TSrc, TDst>(IRow input, int col, Delegate poke, Delegate peek, Func<TSrc, TDst> convert) { - var getter = input.GetGetter<VBuffer<DvText>>(col); - var typedPoke = poke as Poke<TRow, string[]>; - var typedPeek = peek as Peek<TRow, string[]>; + var getter = input.GetGetter<VBuffer<TSrc>>(col); + var typedPoke = poke as Poke<TRow, TDst[]>; + var typedPeek = peek as Peek<TRow, TDst[]>; Contracts.AssertValue(typedPoke); Contracts.AssertValue(typedPeek); - VBuffer<DvText> value = default(VBuffer<DvText>); - string[] buf = null; + VBuffer<TSrc> value = default; + TDst[] buf = null; return row => { getter(ref value); typedPeek(row, Position, ref buf); if (Utils.Size(buf) != value.Length) - buf = new string[value.Length]; + buf = new TDst[value.Length]; foreach (var pair in value.Items(true)) - buf[pair.Key] = pair.Value.ToString(); + buf[pair.Key] = convert(pair.Value); typedPoke(row, buf); }; } - private Action<TRow> CreateVBufferToArraySetter<TDst>(IRow input, int col, Delegate poke, Delegate peek) + private Action<TRow> CreateDirectVBufferSetter<TDst>(IRow input, int col, Delegate poke, Delegate peek) { var getter = input.GetGetter<VBuffer<TDst>>(col); var typedPoke = poke as Poke<TRow, TDst[]>; @@ -386,29 +499,17 @@ private Action<TRow> CreateVBufferToArraySetter<TDst>(IRow input, int col, Deleg }; } - private static Action<TRow> CreateTextToStringSetter(IRow input, int col, Delegate poke) - { - var getter = input.GetGetter<DvText>(col); - var typedPoke = poke as Poke<TRow, string>; - Contracts.AssertValue(typedPoke); - DvText value = default(DvText); - return row => - { - getter(ref value); - typedPoke(row, value.ToString()); - }; - } - - private static Action<TRow> CreateDvBoolToBoolSetter(IRow input, int col, Delegate poke) + private static Action<TRow> CreateConvertingActionSetter<TSrc, TDst>(IRow input, int col, Delegate poke, Func<TSrc, TDst> convert) { - var getter = input.GetGetter<DvBool>(col); - var typedPoke = poke as Poke<TRow, bool>; + var getter = input.GetGetter<TSrc>(col); + var typedPoke = poke as Poke<TRow, TDst>; Contracts.AssertValue(typedPoke); - DvBool value = default(DvBool); + TSrc value = default; return row => { getter(ref value); - typedPoke(row, Convert.ToBoolean(value.RawValue)); + var toPoke = convert(value); + typedPoke(row, toPoke); }; } diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs index 5ed5ded1c1..358227399b 100644 --- a/src/Microsoft.ML.Core/Data/DataKind.cs +++ b/src/Microsoft.ML.Core/Data/DataKind.cs @@ -83,22 +83,22 @@ public static ulong ToMaxInt(this DataKind kind) { switch (kind) { - case DataKind.I1: - return (ulong)sbyte.MaxValue; - case DataKind.U1: - return byte.MaxValue; - case DataKind.I2: - return (ulong)short.MaxValue; - case DataKind.U2: - return ushort.MaxValue; - case DataKind.I4: - return int.MaxValue; - case DataKind.U4: - return uint.MaxValue; - case DataKind.I8: - return long.MaxValue; - case DataKind.U8: - return ulong.MaxValue; + case DataKind.I1: + return (ulong)sbyte.MaxValue; + case DataKind.U1: + return byte.MaxValue; + case DataKind.I2: + return (ulong)short.MaxValue; + case DataKind.U2: + return ushort.MaxValue; + case DataKind.I4: + return int.MaxValue; + case DataKind.U4: + return uint.MaxValue; + case DataKind.I8: + return long.MaxValue; + case DataKind.U8: + return ulong.MaxValue; } return 0; @@ -112,22 +112,22 @@ public static long ToMinInt(this DataKind kind) { switch (kind) { - case DataKind.I1: - return sbyte.MinValue; - case DataKind.U1: - return byte.MinValue; - case DataKind.I2: - return short.MinValue; - case DataKind.U2: - return ushort.MinValue; - case DataKind.I4: - return int.MinValue; - case DataKind.U4: - return uint.MinValue; - case DataKind.I8: - return long.MinValue; - case DataKind.U8: - return 0; + case DataKind.I1: + return sbyte.MinValue; + case DataKind.U1: + return byte.MinValue; + case DataKind.I2: + return short.MinValue; + case DataKind.U2: + return ushort.MinValue; + case DataKind.I4: + return int.MinValue; + case DataKind.U4: + return uint.MinValue; + case DataKind.I8: + return long.MinValue; + case DataKind.U8: + return 0; } return 1; @@ -140,38 +140,38 @@ public static Type ToType(this DataKind kind) { switch (kind) { - case DataKind.I1: - return typeof(DvInt1); - case DataKind.U1: - return typeof(byte); - case DataKind.I2: - return typeof(DvInt2); - case DataKind.U2: - return typeof(ushort); - case DataKind.I4: - return typeof(DvInt4); - case DataKind.U4: - return typeof(uint); - case DataKind.I8: - return typeof(DvInt8); - case DataKind.U8: - return typeof(ulong); - case DataKind.R4: - return typeof(Single); - case DataKind.R8: - return typeof(Double); - case DataKind.TX: - return typeof(DvText); - case DataKind.BL: - return typeof(DvBool); - case DataKind.TS: - return typeof(DvTimeSpan); - case DataKind.DT: - return typeof(DvDateTime); - case DataKind.DZ: - return typeof(DvDateTimeZone); - case DataKind.UG: - return typeof(UInt128); + case DataKind.I1: + return typeof(DvInt1); + case DataKind.U1: + return typeof(byte); + case DataKind.I2: + return typeof(DvInt2); + case DataKind.U2: + return typeof(ushort); + case DataKind.I4: + return typeof(DvInt4); + case DataKind.U4: + return typeof(uint); + case DataKind.I8: + return typeof(DvInt8); + case DataKind.U8: + return typeof(ulong); + case DataKind.R4: + return typeof(Single); + case DataKind.R8: + return typeof(Double); + case DataKind.TX: + return typeof(DvText); + case DataKind.BL: + return typeof(DvBool); + case DataKind.TS: + return typeof(DvTimeSpan); + case DataKind.DT: + return typeof(DvDateTime); + case DataKind.DZ: + return typeof(DvDateTimeZone); + case DataKind.UG: + return typeof(UInt128); } return null; @@ -185,29 +185,29 @@ public static bool TryGetDataKind(this Type type, out DataKind kind) Contracts.CheckValueOrNull(type); // REVIEW: Make this more efficient. Should we have a global dictionary? - if (type == typeof(DvInt1)) + if (type == typeof(DvInt1) || type == typeof(sbyte) || type == typeof(sbyte?)) kind = DataKind.I1; - else if (type == typeof(byte)) + else if (type == typeof(byte) || type == typeof(byte?)) kind = DataKind.U1; - else if (type == typeof(DvInt2)) + else if (type == typeof(DvInt2)|| type== typeof(short) || type == typeof(short?)) kind = DataKind.I2; - else if (type == typeof(ushort)) + else if (type == typeof(ushort)|| type == typeof(ushort?)) kind = DataKind.U2; - else if (type == typeof(DvInt4)) + else if (type == typeof(DvInt4) || type == typeof(int)|| type == typeof(int?)) kind = DataKind.I4; - else if (type == typeof(uint)) + else if (type == typeof(uint)|| type == typeof(uint?)) kind = DataKind.U4; - else if (type == typeof(DvInt8)) + else if (type == typeof(DvInt8) || type==typeof(long)|| type == typeof(long?)) kind = DataKind.I8; - else if (type == typeof(ulong)) + else if (type == typeof(ulong)|| type == typeof(ulong?)) kind = DataKind.U8; - else if (type == typeof(Single)) + else if (type == typeof(Single)|| type == typeof(Single?)) kind = DataKind.R4; - else if (type == typeof(Double)) + else if (type == typeof(Double)|| type == typeof(Double?)) kind = DataKind.R8; else if (type == typeof(DvText)) kind = DataKind.TX; - else if (type == typeof(DvBool) || type == typeof(bool) ||type ==typeof(bool?)) + else if (type == typeof(DvBool) || type == typeof(bool) || type == typeof(bool?)) kind = DataKind.BL; else if (type == typeof(DvTimeSpan)) kind = DataKind.TS; @@ -234,38 +234,38 @@ public static string GetString(this DataKind kind) { switch (kind) { - case DataKind.I1: - return "I1"; - case DataKind.I2: - return "I2"; - case DataKind.I4: - return "I4"; - case DataKind.I8: - return "I8"; - case DataKind.U1: - return "U1"; - case DataKind.U2: - return "U2"; - case DataKind.U4: - return "U4"; - case DataKind.U8: - return "U8"; - case DataKind.R4: - return "R4"; - case DataKind.R8: - return "R8"; - case DataKind.BL: - return "BL"; - case DataKind.TX: - return "TX"; - case DataKind.TS: - return "TS"; - case DataKind.DT: - return "DT"; - case DataKind.DZ: - return "DZ"; - case DataKind.UG: - return "UG"; + case DataKind.I1: + return "I1"; + case DataKind.I2: + return "I2"; + case DataKind.I4: + return "I4"; + case DataKind.I8: + return "I8"; + case DataKind.U1: + return "U1"; + case DataKind.U2: + return "U2"; + case DataKind.U4: + return "U4"; + case DataKind.U8: + return "U8"; + case DataKind.R4: + return "R4"; + case DataKind.R8: + return "R8"; + case DataKind.BL: + return "BL"; + case DataKind.TX: + return "TX"; + case DataKind.TS: + return "TS"; + case DataKind.DT: + return "DT"; + case DataKind.DZ: + return "DZ"; + case DataKind.UG: + return "UG"; } return ""; } diff --git a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs index 42b85ae20f..87e23952d6 100644 --- a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs +++ b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs @@ -9,8 +9,10 @@ using Microsoft.ML.TestFramework; using Microsoft.ML.Trainers; using Microsoft.ML.Transforms; +using System; using System.Collections.Generic; using System.Linq; +using System.Reflection; using Xunit; using Xunit.Abstractions; @@ -205,5 +207,407 @@ public class IrisPrediction public float[] PredictedLabels; } + public class ConversionSimpleClass + { + public int fInt; + public uint fuInt; + public short fShort; + public ushort fuShort; + public sbyte fsByte; + public byte fByte; + public long fLong; + public ulong fuLong; + public float fFloat; + public double fDouble; + public bool fBool; + public string fString; + } + + public class ConversionNullalbeClass + { + public int? fInt; + public uint? fuInt; + public short? fShort; + public ushort? fuShort; + public sbyte? fsByte; + public byte? fByte; + public long? fLong; + public ulong? fuLong; + public float? fFloat; + public double? fDouble; + public bool? fBool; + public string fString; + } + + public bool CompareObjectValues(object x, object y, Type type) + { + // By default behaviour for DvText is to be empty string, while for string is null. + // So if we do roundtrip string-> DvText -> string all null string become empty strings. + // Therefore replace all null values to empty string if field is string. + if (type == typeof(string) && x == null) + x = ""; + if (type == typeof(string) && y == null) + y = ""; + if (x == null && y == null) + return true; + if (x == null && y != null) + return false; + return x.Equals(y); + } + + public bool CompareThroughReflection<T>(T x, T y) + { + foreach (var field in typeof(T).GetFields()) + { + var xvalue = field.GetValue(x); + var yvalue = field.GetValue(y); + if (field.FieldType.IsArray) + { + if (!CompareArrayValues(xvalue as Array, yvalue as Array)) + return false; + } + else + { + if (!CompareObjectValues(xvalue, yvalue, field.FieldType)) + return false; + } + } + return true; + } + + public bool CompareArrayValues(Array x, Array y) + { + if (x == null && y == null) return true; + if ((x == null && y != null) || (y == null && x != null)) + return false; + if (x.Length != y.Length) + return false; + for (int i = 0; i < x.Length; i++) + if (!CompareObjectValues(x.GetValue(i), y.GetValue(i), x.GetType().GetElementType())) + return false; + return true; + } + + public class ClassWithConstField + { + public const string ConstString = "N"; + public string fString; + public const int ConstInt = 100; + public int fInt; + } + + [Fact] + public void RoundTripConversionWithBasicTypes() + { + var data = new List<ConversionSimpleClass> + { + new ConversionSimpleClass() + { + fInt = int.MaxValue - 1, + fuInt = uint.MaxValue - 1, + fBool = true, + fsByte = sbyte.MaxValue - 1, + fByte = byte.MaxValue - 1, + fDouble = double.MaxValue - 1, + fFloat = float.MaxValue - 1, + fLong = long.MaxValue - 1, + fuLong = ulong.MaxValue - 1, + fShort = short.MaxValue - 1, + fuShort = ushort.MaxValue - 1, + fString = null + }, + new ConversionSimpleClass() + { + fInt = int.MaxValue, + fuInt = uint.MaxValue, + fBool = true, + fsByte = sbyte.MaxValue, + fByte = byte.MaxValue, + fDouble = double.MaxValue, + fFloat = float.MaxValue, + fLong = long.MaxValue, + fuLong = ulong.MaxValue, + fShort = short.MaxValue, + fuShort = ushort.MaxValue, + fString = "ooh" + }, + new ConversionSimpleClass() + { + fInt = int.MinValue + 1, + fuInt = uint.MinValue + 1, + fBool = false, + fsByte = sbyte.MinValue + 1, + fByte = byte.MinValue + 1, + fDouble = double.MinValue + 1, + fFloat = float.MinValue + 1, + fLong = long.MinValue + 1, + fuLong = ulong.MinValue + 1, + fShort = short.MinValue + 1, + fuShort = ushort.MinValue + 1, + fString = "" + }, + new ConversionSimpleClass() + }; + + var dataNullable = new List<ConversionNullalbeClass> + { + new ConversionNullalbeClass() + { + fInt = int.MaxValue - 1, + fuInt = uint.MaxValue - 1, + fBool = true, + fsByte = sbyte.MaxValue - 1, + fByte = byte.MaxValue - 1, + fDouble = double.MaxValue - 1, + fFloat = float.MaxValue - 1, + fLong = long.MaxValue - 1, + fuLong = ulong.MaxValue - 1, + fShort = short.MaxValue - 1, + fuShort = ushort.MaxValue - 1, + fString = "ha" + }, + new ConversionNullalbeClass() + { + fInt = int.MaxValue, + fuInt = uint.MaxValue, + fBool = true, + fsByte = sbyte.MaxValue, + fByte = byte.MaxValue, + fDouble = double.MaxValue, + fFloat = float.MaxValue, + fLong = long.MaxValue, + fuLong = ulong.MaxValue, + fShort = short.MaxValue, + fuShort = ushort.MaxValue, + fString = "ooh" + }, + new ConversionNullalbeClass() + { + fInt = int.MinValue + 1, + fuInt = uint.MinValue, + fBool = false, + fsByte = sbyte.MinValue + 1, + fByte = byte.MinValue, + fDouble = double.MinValue + 1, + fFloat = float.MinValue + 1, + fLong = long.MinValue + 1, + fuLong = ulong.MinValue, + fShort = short.MinValue + 1, + fuShort = ushort.MinValue, + fString = "" + }, + new ConversionNullalbeClass() + }; + + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable<ConversionSimpleClass>(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) + { + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); + } + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); + + dataView = ComponentCreation.CreateDataView(env, dataNullable); + var enumeratorNullable = dataView.AsEnumerable<ConversionNullalbeClass>(env, false).GetEnumerator(); + var originalNullableEnumerator = dataNullable.GetEnumerator(); + while (enumeratorNullable.MoveNext() && originalNullableEnumerator.MoveNext()) + { + Assert.True(CompareThroughReflection(enumeratorNullable.Current, originalNullableEnumerator.Current)); + } + Assert.True(!enumeratorNullable.MoveNext() && !originalNullableEnumerator.MoveNext()); + } + } + + public class ConversionNotSupportedMinValueClass + { + public int fInt; + public long fLong; + public short fShort; + public sbyte fSByte; + } + + [Fact] + public void ConversionExceptionsBehavior() + { + using (var env = new TlcEnvironment()) + { + var data = new ConversionNotSupportedMinValueClass[1]; + foreach (var field in typeof(ConversionNotSupportedMinValueClass).GetFields()) + { + data[0] = new ConversionNotSupportedMinValueClass(); + FieldInfo fi; + if ((fi = field.FieldType.GetField("MinValue")) != null) + { + field.SetValue(data[0], fi.GetValue(null)); + } + var dataView = ComponentCreation.CreateDataView(env, data); + var enumerator = dataView.AsEnumerable<ConversionNotSupportedMinValueClass>(env, false).GetEnumerator(); + try + { + enumerator.MoveNext(); + Assert.True(false); + } + catch + { + } + } + } + } + + public class ConversionLossMinValueClass + { + public int? fInt; + public long? fLong; + public short? fShort; + public sbyte? fSByte; + } + + [Fact] + public void ConversionMinValueToNullBehavior() + { + using (var env = new TlcEnvironment()) + { + + var data = new List<ConversionLossMinValueClass> + { + new ConversionLossMinValueClass() { fSByte = null, fInt = null, fLong = null, fShort = null }, + new ConversionLossMinValueClass() { fSByte = sbyte.MinValue, fInt = int.MinValue, fLong = long.MinValue, fShort = short.MinValue } + }; + foreach (var field in typeof(ConversionLossMinValueClass).GetFields()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var enumerator = dataView.AsEnumerable<ConversionLossMinValueClass>(env, false).GetEnumerator(); + while (enumerator.MoveNext()) + { + Assert.True(enumerator.Current.fInt == null && enumerator.Current.fLong == null && + enumerator.Current.fSByte == null && enumerator.Current.fShort == null); + } + } + } + } + + [Fact] + public void ClassWithConstFieldsConversion() + { + var data = new List<ClassWithConstField>() + { + new ClassWithConstField(){ fInt=1, fString ="lala" }, + new ClassWithConstField(){ fInt=-1, fString ="" }, + new ClassWithConstField(){ fInt=0, fString =null } + }; + + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable<ClassWithConstField>(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); + } + } + + public class ClassWithArrays + { + public string[] fString; + public int[] fInt; + public uint[] fuInt; + public short[] fShort; + public ushort[] fuShort; + public sbyte[] fsByte; + public byte[] fByte; + public long[] fLong; + public ulong[] fuLong; + public float[] fFloat; + public double[] fDouble; + public bool[] fBool; + } + + public class ClassWithNullableArrays + { + public string[] fString; + public int?[] fInt; + public uint?[] fuInt; + public short?[] fShort; + public ushort?[] fuShort; + public sbyte?[] fsByte; + public byte?[] fByte; + public long?[] fLong; + public ulong?[] fuLong; + public float?[] fFloat; + public double?[] fDouble; + public bool?[] fBool; + } + + [Fact] + public void RoundTripConversionWithArrays() + { + + var data = new List<ClassWithArrays> + { + new ClassWithArrays() + { + fInt = new int[3] { 0, 1, 2 }, + fFloat = new float[3] { -0.99f, 0f, 0.99f }, + fString = new string[2] { "hola", "lola" }, + fBool = new bool[2] { true, false }, + fByte = new byte[3] { 0, 124, 255 }, + fDouble = new double[3] { -1, 0, 1 }, + fLong = new long[] { 0, 1, 2 }, + fsByte = new sbyte[3] { -127, 127, 0 }, + fShort = new short[3] { 0, 1225, 32767 }, + fuInt = new uint[2] { 0, uint.MaxValue }, + fuLong = new ulong[2] { ulong.MaxValue, 0 }, + fuShort = new ushort[2] { 0, ushort.MaxValue } + }, + new ClassWithArrays() { fInt = new int[3] { -2, 1, 0 }, fFloat = new float[3] { 0.99f, 0f, -0.99f }, fString = new string[2] { "", null } }, + new ClassWithArrays() + }; + + var nullableData = new List<ClassWithNullableArrays> + { + new ClassWithNullableArrays() + { + fInt = new int?[3] { null, -1, 1 }, + fFloat = new float?[3] { -0.99f, null, 0.99f }, + fString = new string[2] { null, "" }, + fBool = new bool?[3] { true, null, false }, + fByte = new byte?[4] { 0, 125, null, 255 }, + fDouble = new double?[3] { -1, null, 1 }, + fLong = new long?[] { null, -1, 1 }, + fsByte = new sbyte?[3] { -127, 127, null }, + fShort = new short?[3] { 0, null, 32767 }, + fuInt = new uint?[4] { null, 42, 0, uint.MaxValue }, + fuLong = new ulong?[3] { ulong.MaxValue, null, 0 }, + fuShort = new ushort?[3] { 0, null, ushort.MaxValue } + }, + new ClassWithNullableArrays() { fInt = new int?[3] { -2, 1, 0 }, fFloat = new float?[3] { 0.99f, 0f, -0.99f }, fString = new string[2] { "lola", "hola" } }, + new ClassWithNullableArrays() + }; + + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable<ClassWithArrays>(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) + { + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); + } + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); + + var nullableDataView = ComponentCreation.CreateDataView(env, nullableData); + var enumeratorNullable = nullableDataView.AsEnumerable<ClassWithNullableArrays>(env, false).GetEnumerator(); + var originalNullalbleEnumerator = nullableData.GetEnumerator(); + while (enumeratorNullable.MoveNext() && originalNullalbleEnumerator.MoveNext()) + { + Assert.True(CompareThroughReflection(enumeratorNullable.Current, originalNullalbleEnumerator.Current)); + } + Assert.True(!enumeratorNullable.MoveNext() && !originalNullalbleEnumerator.MoveNext()); + } + } } }