diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index 3d285d06d8..140c93753c 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -97,6 +97,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer", EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer.Tests", "test\Microsoft.ML.CodeAnalyzer.Tests\Microsoft.ML.CodeAnalyzer.Tests.csproj", "{3E4ABF07-7970-4BE6-B45B-A13D3C397545}" EndProject +Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Microsoft.ML.FSharp.Tests", "test\Microsoft.ML.FSharp.Tests\Microsoft.ML.FSharp.Tests.fsproj", "{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}" +EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.ImageAnalytics", "src\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj", "{00E38F77-1E61-4CDF-8F97-1417D4E85053}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.HalLearners", "src\Microsoft.ML.HalLearners\Microsoft.ML.HalLearners.csproj", "{A7222F41-1CF0-47D9-B80C-B4D77B027A61}" @@ -333,6 +335,14 @@ Global {3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release|Any CPU.Build.0 = Release|Any CPU {3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU {3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU + {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug|Any CPU.Build.0 = Debug|Any CPU + {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU + {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU + {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release|Any CPU.ActiveCfg = Release|Any CPU + {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release|Any CPU.Build.0 = Release|Any CPU + {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU + 
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU {00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug|Any CPU.Build.0 = Debug|Any CPU {00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU @@ -387,6 +397,7 @@ Global {BF66A305-DF10-47E4-8D81-42049B149D2B} = {D3D38B03-B557-484D-8348-8BADEE4DF592} {B4E55B2D-2A92-46E7-B72F-E76D6FD83440} = {7F13E156-3EBA-4021-84A5-CD56BA72F99E} {3E4ABF07-7970-4BE6-B45B-A13D3C397545} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} + {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {00E38F77-1E61-4CDF-8F97-1417D4E85053} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {A7222F41-1CF0-47D9-B80C-B4D77B027A61} = {09EADF06-BE25-4228-AB53-95AE3E15B530} EndGlobalSection diff --git a/src/Microsoft.ML.Api/ApiUtils.cs b/src/Microsoft.ML.Api/ApiUtils.cs index 8b8cb5871b..96e821f16e 100644 --- a/src/Microsoft.ML.Api/ApiUtils.cs +++ b/src/Microsoft.ML.Api/ApiUtils.cs @@ -51,14 +51,31 @@ private static OpCode GetAssignmentOpCode(Type t) /// internal static Delegate GeneratePeek(InternalSchemaDefinition.Column column) { - var fieldInfo = column.FieldInfo; - Type fieldType = fieldInfo.FieldType; - - var assignmentOpCode = GetAssignmentOpCode(fieldType); - Func func = GeneratePeek; - var methInfo = func.GetMethodInfo().GetGenericMethodDefinition() - .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType); - return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode }); + switch (column.MemberInfo) + { + case FieldInfo fieldInfo: + Type fieldType = fieldInfo.FieldType; + + var assignmentOpCode = GetAssignmentOpCode(fieldType); + Func func = GeneratePeek; + var methInfo = func.GetMethodInfo().GetGenericMethodDefinition() + .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType); + return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode }); + + 
case PropertyInfo propertyInfo: + Type propertyType = propertyInfo.PropertyType; + + var assignmentOpCodeProp = GetAssignmentOpCode(propertyType); + Func funcProp = GeneratePeek; + var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition() + .MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType); + return (Delegate)methInfoProp.Invoke(null, new object[] { propertyInfo, assignmentOpCodeProp }); + + default: + Contracts.Assert(false); + throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo"); + + } } private static Delegate GeneratePeek(FieldInfo fieldInfo, OpCode assignmentOpCode) @@ -81,6 +98,28 @@ private static Delegate GeneratePeek(FieldInfo fieldInfo, Op return mb.CreateDelegate(typeof(Peek)); } + private static Delegate GeneratePeek(PropertyInfo propertyInfo, OpCode assignmentOpCode) + { + // REVIEW: It seems like we really should cache these, instead of generating them per cursor. + Type[] args = { typeof(TOwn), typeof(TRow), typeof(long), typeof(TValue).MakeByRefType() }; + var mb = new DynamicMethod("Peek", null, args, typeof(TOwn), true); + var il = mb.GetILGenerator(); + var minfo = propertyInfo.GetGetMethod(); + var opcode = (minfo.IsVirtual || minfo.IsAbstract) ? OpCodes.Callvirt : OpCodes.Call; + + il.Emit(OpCodes.Ldarg_3); // push arg3 + il.Emit(OpCodes.Ldarg_1); // push arg1 + il.Emit(opcode, minfo); // call [stack top].get_[propertyInfo]() + // Stobj needs to be coupled with a type. + if (assignmentOpCode == OpCodes.Stobj) // [stack top-1] = [stack top] + il.Emit(assignmentOpCode, propertyInfo.PropertyType); + else + il.Emit(assignmentOpCode); + il.Emit(OpCodes.Ret); // ret + + return mb.CreateDelegate(typeof(Peek)); + } + /// /// Each of the specialized 'poke' methods sets the appropriate field value of an instance of T /// to the provided value. 
So, the call is 'peek(userObject, providedValue)' and the logic is @@ -88,14 +127,30 @@ private static Delegate GeneratePeek(FieldInfo fieldInfo, Op /// internal static Delegate GeneratePoke(InternalSchemaDefinition.Column column) { - var fieldInfo = column.FieldInfo; - Type fieldType = fieldInfo.FieldType; - - var assignmentOpCode = GetAssignmentOpCode(fieldType); - Func func = GeneratePoke; - var methInfo = func.GetMethodInfo().GetGenericMethodDefinition() - .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType); - return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode }); + switch (column.MemberInfo) + { + case FieldInfo fieldInfo: + Type fieldType = fieldInfo.FieldType; + + var assignmentOpCode = GetAssignmentOpCode(fieldType); + Func func = GeneratePoke; + var methInfo = func.GetMethodInfo().GetGenericMethodDefinition() + .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType); + return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode }); + + case PropertyInfo propertyInfo: + Type propertyType = propertyInfo.PropertyType; + + var assignmentOpCodeProp = GetAssignmentOpCode(propertyType); + Func funcProp = GeneratePoke; + var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition() + .MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType); + return (Delegate)methInfoProp.Invoke(null, new object[] { propertyInfo }); + + default: + Contracts.Assert(false); + throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo"); + } } private static Delegate GeneratePoke(FieldInfo fieldInfo, OpCode assignmentOpCode) @@ -115,5 +170,20 @@ private static Delegate GeneratePoke(FieldInfo fieldInfo, Op il.Emit(OpCodes.Ret); // ret return mb.CreateDelegate(typeof(Poke), null); } + + private static Delegate GeneratePoke(PropertyInfo propertyInfo) + { + Type[] args = { typeof(TOwn), typeof(TRow), typeof(TValue) }; + var mb = new DynamicMethod("Poke", null, args, typeof(TOwn), true); + var il = 
mb.GetILGenerator(); + var minfo = propertyInfo.GetSetMethod(); + var opcode = (minfo.IsVirtual || minfo.IsAbstract) ? OpCodes.Callvirt : OpCodes.Call; + + il.Emit(OpCodes.Ldarg_1); // push arg1 + il.Emit(OpCodes.Ldarg_2); // push arg2 + il.Emit(opcode, minfo); // call [stack top-1].set_[propertyInfo]([stack top]) + il.Emit(OpCodes.Ret); // ret + return mb.CreateDelegate(typeof(Poke), null); + } } } diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index e940ea9d4d..c50e48e16f 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -118,7 +118,7 @@ private Delegate CreateGetter(int index) var colType = DataView.Schema.GetColumnType(index); var column = DataView._schema.SchemaDefn.Columns[index]; - var outputType = column.IsComputed ? column.ReturnType : column.FieldInfo.FieldType; + var outputType = column.OutputType; var genericType = outputType; Func del; diff --git a/src/Microsoft.ML.Api/InternalSchemaDefinition.cs b/src/Microsoft.ML.Api/InternalSchemaDefinition.cs index 3edf7599a4..4c20f25d62 100644 --- a/src/Microsoft.ML.Api/InternalSchemaDefinition.cs +++ b/src/Microsoft.ML.Api/InternalSchemaDefinition.cs @@ -23,21 +23,23 @@ internal sealed class InternalSchemaDefinition public class Column { public readonly string ColumnName; - public readonly FieldInfo FieldInfo; + public readonly MemberInfo MemberInfo; public readonly ParameterInfo ReturnParameterInfo; public readonly ColumnType ColumnType; public readonly bool IsComputed; public readonly Delegate Generator; private readonly Dictionary _metadata; public Dictionary Metadata { get { return _metadata; } } - public Type ReturnType {get { return ReturnParameterInfo.ParameterType.GetElementType(); }} + public Type ComputedReturnType {get { return ReturnParameterInfo.ParameterType.GetElementType(); }} + public Type FieldOrPropertyType => (MemberInfo is FieldInfo) ? 
(MemberInfo as FieldInfo).FieldType : (MemberInfo as PropertyInfo).PropertyType; + public Type OutputType => IsComputed ? ComputedReturnType : FieldOrPropertyType; - public Column(string columnName, ColumnType columnType, FieldInfo fieldInfo) : - this(columnName, columnType, fieldInfo, null, null) { } + public Column(string columnName, ColumnType columnType, MemberInfo memberInfo) : + this(columnName, columnType, memberInfo, null, null) { } - public Column(string columnName, ColumnType columnType, FieldInfo fieldInfo, + public Column(string columnName, ColumnType columnType, MemberInfo memberInfo, Dictionary metadataInfos) : - this(columnName, columnType, fieldInfo, null, metadataInfos) { } + this(columnName, columnType, memberInfo, null, metadataInfos) { } public Column(string columnName, ColumnType columnType, Delegate generator) : this(columnName, columnType, null, generator, null) { } @@ -46,7 +48,7 @@ public Column(string columnName, ColumnType columnType, Delegate generator, Dictionary metadataInfos) : this(columnName, columnType, null, generator, metadataInfos) { } - private Column(string columnName, ColumnType columnType, FieldInfo fieldInfo = null, + private Column(string columnName, ColumnType columnType, MemberInfo memberInfo = null, Delegate generator = null, Dictionary metadataInfos = null) { Contracts.AssertNonEmpty(columnName); @@ -55,8 +57,8 @@ private Column(string columnName, ColumnType columnType, FieldInfo fieldInfo = n if (generator == null) { - Contracts.AssertValue(fieldInfo); - FieldInfo = fieldInfo; + Contracts.AssertValue(memberInfo); + MemberInfo = memberInfo; } else { @@ -95,8 +97,8 @@ public void AssertRep() // If Column is computed type, it must have a generator. Contracts.Assert(IsComputed == (Generator != null)); - // Column must have either a generator or a fieldInfo value. - Contracts.Assert((Generator == null) != (FieldInfo == null)); + // Column must have either a generator or a memberInfo value. 
+ Contracts.Assert((Generator == null) != (MemberInfo == null)); // Additional Checks if there is a generator. if (Generator == null) @@ -115,9 +117,7 @@ public void AssertRep() Contracts.Assert(Generator.GetMethodInfo().ReturnType == typeof(void)); // Checks that the return type of the generator is compatible with ColumnType. - bool isVector; - DataKind datakind; - GetVectorAndKind(ReturnType, "return type", out isVector, out datakind); + GetVectorAndKind(ComputedReturnType, "return type", out bool isVector, out DataKind datakind); Contracts.Assert(isVector == ColumnType.IsVector); Contracts.Assert(datakind == ColumnType.ItemType.RawKind); } @@ -131,19 +131,30 @@ private InternalSchemaDefinition(Column[] columns) } /// - /// Given a field info on a type, returns whether this appears to be a vector type, + /// Given a field or property info on a type, returns whether this appears to be a vector type, /// and also the associated data kind for this type. If a data kind could not /// be determined, this will throw. /// - /// The field info to inspect. + /// The field or property info to inspect. /// Whether this appears to be a vector type. /// The data kind of the type, or items of this type if vector. 
- public static void GetVectorAndKind(FieldInfo fieldInfo, out bool isVector, out DataKind kind) + public static void GetVectorAndKind(MemberInfo memberInfo, out bool isVector, out DataKind kind) { - Contracts.AssertValue(fieldInfo); - Type rawFieldType = fieldInfo.FieldType; - var name = fieldInfo.Name; - GetVectorAndKind(rawFieldType, name, out isVector, out kind); + Contracts.AssertValue(memberInfo); + switch (memberInfo) + { + case FieldInfo fieldInfo: + GetVectorAndKind(fieldInfo.FieldType, fieldInfo.Name, out isVector, out kind); + break; + + case PropertyInfo propertyInfo: + GetVectorAndKind(propertyInfo.PropertyType, propertyInfo.Name, out isVector, out kind); + break; + + default: + Contracts.Assert(false); + throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo"); + } } /// @@ -211,23 +222,27 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us bool isVector; DataKind kind; - FieldInfo fieldInfo = null; + MemberInfo memberInfo = null; if (!col.IsComputed) { - fieldInfo = userType.GetField(col.MemberName); + memberInfo = userType.GetField(col.MemberName); + + if (memberInfo == null) + memberInfo = userType.GetProperty(col.MemberName); - if (fieldInfo == null) - throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No field with name '{0}' found in type '{1}'", + if (memberInfo == null) + throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No field or property with name '{0}' found in type '{1}'", col.MemberName, userType.FullName); //Clause to handle the field that may be used to expose the cursor channel. //This field does not need a column. 
- if (fieldInfo.FieldType == typeof(IChannel)) + if ( (memberInfo is FieldInfo && (memberInfo as FieldInfo).FieldType == typeof(IChannel)) || + (memberInfo is PropertyInfo && (memberInfo as PropertyInfo).PropertyType == typeof(IChannel))) continue; - GetVectorAndKind(fieldInfo, out isVector, out kind); + GetVectorAndKind(memberInfo, out isVector, out kind); } else { @@ -268,7 +283,7 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us dstCols[i] = col.IsComputed ? new Column(colName, colType, col.Generator, col.Metadata) - : new Column(colName, colType, fieldInfo, col.Metadata); + : new Column(colName, colType, memberInfo, col.Metadata); } return new InternalSchemaDefinition(dstCols); diff --git a/src/Microsoft.ML.Api/SchemaDefinition.cs b/src/Microsoft.ML.Api/SchemaDefinition.cs index e08845a87e..3258df4ffd 100644 --- a/src/Microsoft.ML.Api/SchemaDefinition.cs +++ b/src/Microsoft.ML.Api/SchemaDefinition.cs @@ -14,7 +14,7 @@ namespace Microsoft.ML.Runtime.Api /// /// Attach to a member of a class to indicate that the item type should be of class key. /// - [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)] + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] public sealed class KeyTypeAttribute : Attribute { // REVIEW: Property based, but should I just have a constructor? @@ -46,7 +46,7 @@ public KeyTypeAttribute() /// Allows a member to be marked as a vector valued field, primarily allowing one to set /// the dimensionality of the resulting array. 
/// - [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)] + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] public sealed class VectorTypeAttribute : Attribute { private readonly int[] _dims; @@ -66,7 +66,7 @@ public VectorTypeAttribute(params int[] dims) /// Describes column information such as name and the source columns indicies that this /// column encapsulates. /// - [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)] + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] public sealed class ColumnAttribute : Attribute { public ColumnAttribute(string ordinal, string name = null) @@ -97,7 +97,7 @@ public ColumnAttribute(string ordinal, string name = null) /// Allows a member to specify its column name directly, as opposed to the default /// behavior of using the member name as the column name. /// - [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)] + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] public sealed class ColumnNameAttribute : Attribute { private readonly string _name; @@ -119,7 +119,7 @@ public ColumnNameAttribute(string name) /// /// Mark this member as not being exposed as a column in the schema. /// - [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)] + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] public sealed class NoColumnAttribute : Attribute { } @@ -128,7 +128,7 @@ public sealed class NoColumnAttribute : Attribute /// Mark a member that implements exactly IChannel as being permitted to receive /// channel information from an external channel. 
/// - [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)] + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] public sealed class CursorChannelAttribute : Attribute { /// @@ -158,19 +158,40 @@ public static bool TrySetCursorChannel(IExceptionContext ectx, T obj, IChanne .Where(x => x.GetCustomAttributes(typeof(CursorChannelAttribute), false).Any()) .ToArray(); + var cursorChannelAttrProperties = typeof(T) + .GetProperties(BindingFlags.Public | BindingFlags.Instance) + .Where(x => x.CanRead && x.CanWrite && x.GetGetMethod() != null && x.GetSetMethod() != null && x.GetIndexParameters().Length == 0) + .Where(x => x.GetCustomAttributes(typeof(CursorChannelAttribute), false).Any()); + + var cursorChannelAttrMembers = (cursorChannelAttrFields as IEnumerable).Concat(cursorChannelAttrProperties).ToArray(); + //Check that there is at most one such field. - if (cursorChannelAttrFields.Length == 0) + if (cursorChannelAttrMembers.Length == 0) return false; - ectx.Check(cursorChannelAttrFields.Length == 1, - "Only one field with CursorChannel attribute is allowed."); + ectx.Check(cursorChannelAttrMembers.Length == 1, + "Only one public field or property with CursorChannel attribute is allowed."); //Check that the marked field has type IChannel. 
- var cursorChannelFieldInfo = cursorChannelAttrFields[0]; - ectx.Check(cursorChannelFieldInfo.FieldType == typeof(IChannel), - "Field marked as CursorChannel must have type IChannel."); - - cursorChannelFieldInfo.SetValue(obj, channel); + var cursorChannelAttrMemberInfo = cursorChannelAttrMembers[0]; + switch (cursorChannelAttrMemberInfo) + { + case FieldInfo cursorChannelAttrFieldInfo: + ectx.Check(cursorChannelAttrFieldInfo.FieldType == typeof(IChannel), + "Field marked as CursorChannel must have type IChannel."); + cursorChannelAttrFieldInfo.SetValue(obj, channel); + break; + + case PropertyInfo cursorChannelAttrPropertyInfo: + ectx.Check(cursorChannelAttrPropertyInfo.PropertyType == typeof(IChannel), + "Property marked as CursorChannel must have type IChannel."); + cursorChannelAttrPropertyInfo.SetValue(obj, channel); + break; + + default: + Contracts.Assert(false); + throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo"); + } return true; } } @@ -319,37 +340,63 @@ public static SchemaDefinition Create(Type userType) SchemaDefinition cols = new SchemaDefinition(); HashSet colNames = new HashSet(); - foreach (var fieldInfo in userType.GetFields()) + + var fieldInfos = userType.GetFields(BindingFlags.Public | BindingFlags.Instance); + var propertyInfos = + userType + .GetProperties(BindingFlags.Public | BindingFlags.Instance) + .Where(x => x.CanRead && x.CanWrite && x.GetGetMethod() != null && x.GetSetMethod() != null && x.GetIndexParameters().Length == 0); + + var memberInfos = (fieldInfos as IEnumerable).Concat(propertyInfos).ToArray(); + + foreach (var memberInfo in memberInfos) { // Clause to handle the field that may be used to expose the cursor channel. // This field does not need a column. // REVIEW: maybe validate the channel attribute now, instead // of later at cursor creation. - if (fieldInfo.FieldType == typeof(IChannel)) - continue; - // Const fields do not need to be mapped. 
- if (fieldInfo.IsLiteral) - continue; + switch (memberInfo) + { + case FieldInfo fieldInfo: + if (fieldInfo.FieldType == typeof(IChannel)) + continue; + + // Const fields do not need to be mapped. + if (fieldInfo.IsLiteral) + continue; + + break; - if (fieldInfo.GetCustomAttribute() != null) + case PropertyInfo propertyInfo: + if (propertyInfo.PropertyType == typeof(IChannel)) + continue; + break; + + default: + Contracts.Assert(false); + throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo"); + } + + if (memberInfo.GetCustomAttribute() != null) continue; - var mappingAttr = fieldInfo.GetCustomAttribute(); - var mappingNameAttr = fieldInfo.GetCustomAttribute(); - string name = mappingAttr?.Name ?? mappingNameAttr?.Name ?? fieldInfo.Name; + + var mappingAttr = memberInfo.GetCustomAttribute(); + var mappingNameAttr = memberInfo.GetCustomAttribute(); + string name = mappingAttr?.Name ?? mappingNameAttr?.Name ?? memberInfo.Name; // Disallow duplicate names, because the field enumeration order is not actually // well defined, so we are not gauranteed to have consistent "hiding" from run to // run, across different .NET versions. 
if (!colNames.Add(name)) throw Contracts.ExceptParam(nameof(userType), "Duplicate column name '{0}' detected, this is disallowed", name); - InternalSchemaDefinition.GetVectorAndKind(fieldInfo, out bool isVector, out DataKind kind); + InternalSchemaDefinition.GetVectorAndKind(memberInfo, out bool isVector, out DataKind kind); PrimitiveType itemType; - var keyAttr = fieldInfo.GetCustomAttribute(); + var keyAttr = memberInfo.GetCustomAttribute(); if (keyAttr != null) { if (!KeyType.IsValidDataKind(kind)) - throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", fieldInfo.Name); + throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", memberInfo.Name); itemType = new KeyType(kind, keyAttr.Min, keyAttr.Count, keyAttr.Contiguous); } else @@ -357,9 +404,9 @@ public static SchemaDefinition Create(Type userType) // Get the column type. 
ColumnType columnType; - var vectorAttr = fieldInfo.GetCustomAttribute(); + var vectorAttr = memberInfo.GetCustomAttribute(); if (vectorAttr != null && !isVector) - throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with VectorType attribute, but does not appear to be a vector type", fieldInfo.Name); + throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with VectorType attribute, but does not appear to be a vector type", memberInfo.Name); if (isVector) { int[] dims = vectorAttr?.Dims; @@ -373,7 +420,7 @@ public static SchemaDefinition Create(Type userType) else columnType = itemType; - cols.Add(new Column() { MemberName = fieldInfo.Name, ColumnName = name, ColumnType = columnType }); + cols.Add(new Column() { MemberName = memberInfo.Name, ColumnName = name, ColumnType = columnType }); } return cols; } diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index cd8198e14d..19f9a7cf72 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -103,11 +103,11 @@ private TypedCursorable(IHostEnvironment env, IDataView data, bool ignoreMissing throw _host.Except("Column '{0}' not found in the data view", col.ColumnName); } var realColType = _data.Schema.GetColumnType(colIndex); - if (!IsCompatibleType(realColType, col.FieldInfo)) + if (!IsCompatibleType(realColType, col.MemberInfo)) { throw _host.Except( - "Can't bind the IDataView column '{0}' of type '{1}' to field '{2}' of type '{3}'.", - col.ColumnName, realColType, col.FieldInfo.Name, col.FieldInfo.FieldType.FullName); + "Can't bind the IDataView column '{0}' of type '{1}' to field or property '{2}' of type '{3}'.", + col.ColumnName, realColType, col.MemberInfo.Name, col.FieldOrPropertyType.FullName); } acceptedCols.Add(col); @@ -130,14 +130,12 @@ private TypedCursorable(IHostEnvironment env, IDataView data, bool ignoreMissing } /// - /// Returns whether the column type can be bound to field . 
+ /// Returns whether the column type can be bound to field . /// They must both be vectors or scalars, and the raw data kind should match. /// - private static bool IsCompatibleType(ColumnType colType, FieldInfo fieldInfo) + private static bool IsCompatibleType(ColumnType colType, MemberInfo memberInfo) { - bool isVector; - DataKind kind; - InternalSchemaDefinition.GetVectorAndKind(fieldInfo, out isVector, out kind); + InternalSchemaDefinition.GetVectorAndKind(memberInfo, out bool isVector, out DataKind kind); if (isVector) return colType.IsVector && colType.ItemType.RawKind == kind; else @@ -269,8 +267,7 @@ public ValueGetter GetIdGetter() private Action GenerateSetter(IRow input, int index, InternalSchemaDefinition.Column column, Delegate poke, Delegate peek) { var colType = input.Schema.GetColumnType(index); - var fieldInfo = column.FieldInfo; - var fieldType = fieldInfo.FieldType; + var fieldType = column.OutputType; var genericType = fieldType; Func> del; if (fieldType.IsArray) @@ -431,7 +428,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit else { // REVIEW: Is this even possible? 
- throw Ch.ExceptNotImpl("Type '{0}' is not yet supported.", fieldInfo.FieldType.FullName); + throw Ch.ExceptNotImpl("Type '{0}' is not yet supported.", column.OutputType.FullName); } MethodInfo meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(genericType); return (Action)meth.Invoke(this, new object[] { input, index, poke, peek }); diff --git a/src/Microsoft.ML/Data/TextLoader.cs b/src/Microsoft.ML/Data/TextLoader.cs index 6e89e8a54e..330412185e 100644 --- a/src/Microsoft.ML/Data/TextLoader.cs +++ b/src/Microsoft.ML/Data/TextLoader.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using System; +using System.Collections.Generic; using System.Linq; using System.Reflection; using System.Text.RegularExpressions; @@ -71,20 +72,30 @@ public TextLoader CreateFrom(bool useHeader = false, char separator = '\t', bool allowQuotedStrings = true, bool supportSparse = true, bool trimWhitespace = false) { - var fields = typeof(TInput).GetFields(); - Arguments.Column = new TextLoaderColumn[fields.Length]; - for (int index = 0; index < fields.Length; index++) + var userType = typeof(TInput); + + var fieldInfos = userType.GetFields(BindingFlags.Public | BindingFlags.Instance); + + var propertyInfos = + userType + .GetProperties(BindingFlags.Public | BindingFlags.Instance) + .Where(x => x.CanRead && x.CanWrite && x.GetGetMethod() != null && x.GetSetMethod() != null && x.GetIndexParameters().Length == 0); + + var memberInfos = (fieldInfos as IEnumerable).Concat(propertyInfos).ToArray(); + + Arguments.Column = new TextLoaderColumn[memberInfos.Length]; + for (int index = 0; index < memberInfos.Length; index++) { - var field = fields[index]; - var mappingAttr = field.GetCustomAttribute(); + var memberInfo = memberInfos[index]; + var mappingAttr = memberInfo.GetCustomAttribute(); if (mappingAttr == null) - throw Contracts.Except($"{field.Name} is missing ColumnAttribute"); + throw Contracts.Except($"Field or property 
{memberInfo.Name} is missing ColumnAttribute"); if (Regex.Match(mappingAttr.Ordinal, @"[^(0-9,\*\-~)]+").Success) throw Contracts.Except($"{mappingAttr.Ordinal} contains invalid characters. " + $"Valid characters are 0-9, *, - and ~"); - var name = mappingAttr.Name ?? field.Name; + var name = mappingAttr.Name ?? memberInfo.Name; Runtime.Data.TextLoader.Range[] sources; if (!Runtime.Data.TextLoader.Column.TryParseSourceEx(mappingAttr.Ordinal, out sources)) @@ -96,8 +107,23 @@ public TextLoader CreateFrom(bool useHeader = false, tlc.Name = name; tlc.Source = new TextLoaderRange[sources.Length]; DataKind dk; - if (!TryGetDataKind(field.FieldType.IsArray ? field.FieldType.GetElementType() : field.FieldType, out dk)) - throw Contracts.Except($"{name} is of unsupported type."); + switch (memberInfo) + { + case FieldInfo field: + if (!TryGetDataKind(field.FieldType.IsArray ? field.FieldType.GetElementType() : field.FieldType, out dk)) + throw Contracts.Except($"Field {name} is of unsupported type."); + + break; + + case PropertyInfo property: + if (!TryGetDataKind(property.PropertyType.IsArray ? 
property.PropertyType.GetElementType() : property.PropertyType, out dk)) + throw Contracts.Except($"Property {name} is of unsupported type."); + break; + + default: + Contracts.Assert(false); + throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo"); + } tlc.Type = dk; diff --git a/test/Microsoft.ML.FSharp.Tests/Microsoft.ML.FSharp.Tests.fsproj b/test/Microsoft.ML.FSharp.Tests/Microsoft.ML.FSharp.Tests.fsproj new file mode 100644 index 0000000000..888a983e51 --- /dev/null +++ b/test/Microsoft.ML.FSharp.Tests/Microsoft.ML.FSharp.Tests.fsproj @@ -0,0 +1,52 @@ + + + + netcoreapp2.0 + $(TargetFrameworks); net461 + 2003;$(NoWarn) + false + + x64 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/test/Microsoft.ML.FSharp.Tests/Program.fs b/test/Microsoft.ML.FSharp.Tests/Program.fs new file mode 100644 index 0000000000..f45e4e3c6c --- /dev/null +++ b/test/Microsoft.ML.FSharp.Tests/Program.fs @@ -0,0 +1,9 @@ +namespace Microsoft.ML.FSharp.Tests + +#if NETCOREAPP2_0 +module Program = + + [] + let main _ = 0 +#endif + diff --git a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs new file mode 100644 index 0000000000..570c1e0722 --- /dev/null +++ b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs @@ -0,0 +1,263 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + + +//================================================================================================= +// This test can be run either as a compiled test with .NET Core (on any platform) or +// manually in script form (to help debug it and also check that F# scripting works with ML.NET). +// Running as a script requires using F# Interactive on Windows, and the explicit references below. 
+// The references would normally be created by a package loader for the scripting +// environment, e.g. see https://github.com/isaacabraham/ml-test-experiment/, but +// here we list them explicitly to avoid the dependency on a package loader, +// +// You should build Microsoft.ML.FSharp.Tests in Debug mode for framework net461 +// before running this as a script with F# Interactive by editing the project +// file to have: +// netcoreapp2.0; net461 + +#if INTERACTIVE +#r "netstandard" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Core.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Google.Protobuf.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Newtonsoft.Json.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/System.CodeDom.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/System.Threading.Tasks.Dataflow.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.CpuMath.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Data.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Transforms.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.ResultProcessor.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.PCA.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.KMeansClustering.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.FastTree.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Api.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Sweeper.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.StandardLearners.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.PipelineInference.dll" 
+#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/xunit.core.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/xunit.assert.dll" +#r "System" +#r "System.ComponentModel.Composition" +#r "System.Core" +#r "System.Xml.Linq" + +// Later tests will add data import using F# type providers: +//#r @"../../packages/fsharp.data/3.0.0-beta4/lib/netstandard2.0/FSharp.Data.dll" // this must be referenced from its package location + +#endif + +//================================================================================ +// The tests proper start here + +#if !INTERACTIVE +namespace Microsoft.ML.FSharp.Tests +#endif + +open System +open Microsoft.ML +open Microsoft.ML.Data +open Microsoft.ML.Transforms +open Microsoft.ML.Trainers +open Microsoft.ML.Runtime.Api +open Xunit + +module SmokeTest1 = + + type SentimentData() = + [] + val mutable SentimentText : string + [] + val mutable Sentiment : float32 + + type SentimentPrediction() = + [] + val mutable Sentiment : bool + + [] + let ``FSharp-Sentiment-Smoke-Test`` () = + + // See https://github.com/dotnet/machinelearning/issues/401: forces the loading of ML.NET component assemblies + let _load = + [ typeof; + typeof ] + + let testDataPath = __SOURCE_DIRECTORY__ + @"/../data/wikipedia-detox-250-line-data.tsv" + + let pipeline = LearningPipeline() + + pipeline.Add( + TextLoader(testDataPath).CreateFrom( + Arguments = + TextLoaderArguments( + HasHeader = true, + Column = [| TextLoaderColumn(Name = "Label", + Source = [| TextLoaderRange(0) |], + Type = Nullable (Data.DataKind.Num)) + TextLoaderColumn(Name = "SentimentText", + Source = [| TextLoaderRange(1) |], + Type = Nullable (Data.DataKind.Text)) |] + ))) + + pipeline.Add( + TextFeaturizer( + "Features", [| "SentimentText" |], + KeepDiacritics = false, + KeepPunctuations = false, + TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, + OutputTokens = true, + VectorNormalizer = TextTransformTextNormKind.L2 + )) + + pipeline.Add( + 
FastTreeBinaryClassifier( + NumLeaves = 5, + NumTrees = 5, + MinDocumentsInLeafs = 2 + )) + + let model = pipeline.Train() + + let predictions = + [ SentimentData(SentimentText = "This is a gross exaggeration. Nobody is setting a kangaroo court. There was a simple addition.") + SentimentData(SentimentText = "Sort of ok") + SentimentData(SentimentText = "Joe versus the Volcano Coffee Company is a great film.") ] + |> model.Predict + + let predictionResults = [ for p in predictions -> p.Sentiment ] + Assert.Equal(predictionResults, [ false; true; true ]) + +module SmokeTest2 = + + [] + type SentimentData = + { [] + SentimentText : string + + [] + Sentiment : float32 } + + [] + type SentimentPrediction = + { [] + Sentiment : bool } + + [] + let ``FSharp-Sentiment-Smoke-Test`` () = + + // See https://github.com/dotnet/machinelearning/issues/401: forces the loading of ML.NET component assemblies + let _load = + [ typeof; + typeof ] + + let testDataPath = __SOURCE_DIRECTORY__ + @"/../data/wikipedia-detox-250-line-data.tsv" + + let pipeline = LearningPipeline() + + pipeline.Add( + TextLoader(testDataPath).CreateFrom( + Arguments = + TextLoaderArguments( + HasHeader = true, + Column = [| TextLoaderColumn(Name = "Label", + Source = [| TextLoaderRange(0) |], + Type = Nullable (Data.DataKind.Num)) + TextLoaderColumn(Name = "SentimentText", + Source = [| TextLoaderRange(1) |], + Type = Nullable (Data.DataKind.Text)) |] + ))) + + pipeline.Add( + TextFeaturizer( + "Features", [| "SentimentText" |], + KeepDiacritics = false, + KeepPunctuations = false, + TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, + OutputTokens = true, + VectorNormalizer = TextTransformTextNormKind.L2 + )) + + pipeline.Add( + FastTreeBinaryClassifier( + NumLeaves = 5, + NumTrees = 5, + MinDocumentsInLeafs = 2 + )) + + let model = pipeline.Train() + + let predictions = + [ { SentimentText = "This is a gross exaggeration. Nobody is setting a kangaroo court. 
There was a simple addition."; Sentiment = 0.0f } + { SentimentText = "Sort of ok"; Sentiment = 0.0f } + { SentimentText = "Joe versus the Volcano Coffee Company is a great film."; Sentiment = 0.0f } ] + |> model.Predict + + let predictionResults = [ for p in predictions -> p.Sentiment ] + Assert.Equal(predictionResults, [ false; true; true ]) + +module SmokeTest3 = + + type SentimentData() = + [] + member val SentimentText = "" with get, set + + [] + member val Sentiment = 0.0 with get, set + + type SentimentPrediction() = + [] + member val Sentiment = false with get, set + + [] + let ``FSharp-Sentiment-Smoke-Test`` () = + + // See https://github.com/dotnet/machinelearning/issues/401: forces the loading of ML.NET component assemblies + let _load = + [ typeof; + typeof ] + + let testDataPath = __SOURCE_DIRECTORY__ + @"/../data/wikipedia-detox-250-line-data.tsv" + + let pipeline = LearningPipeline() + + pipeline.Add( + TextLoader(testDataPath).CreateFrom( + Arguments = + TextLoaderArguments( + HasHeader = true, + Column = [| TextLoaderColumn(Name = "Label", + Source = [| TextLoaderRange(0) |], + Type = Nullable (Data.DataKind.Num)) + TextLoaderColumn(Name = "SentimentText", + Source = [| TextLoaderRange(1) |], + Type = Nullable (Data.DataKind.Text)) |] + ))) + + pipeline.Add( + TextFeaturizer( + "Features", [| "SentimentText" |], + KeepDiacritics = false, + KeepPunctuations = false, + TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, + OutputTokens = true, + VectorNormalizer = TextTransformTextNormKind.L2 + )) + + pipeline.Add( + FastTreeBinaryClassifier( + NumLeaves = 5, + NumTrees = 5, + MinDocumentsInLeafs = 2 + )) + + let model = pipeline.Train() + + let predictions = + [ SentimentData(SentimentText = "This is a gross exaggeration. Nobody is setting a kangaroo court. 
There was a simple addition.") + SentimentData(SentimentText = "Sort of ok") + SentimentData(SentimentText = "Joe versus the Volcano Coffee Company is a great film.") ] + |> model.Predict + + let predictionResults = [ for p in predictions -> p.Sentiment ] + Assert.Equal(predictionResults, [ false; true; true ]) + diff --git a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs index 87e23952d6..14a7f473f7 100644 --- a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs +++ b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs @@ -174,6 +174,49 @@ public void CanTrain() } + [Fact] + public void CanTrainProperties() + { + var pipeline = new LearningPipeline(); + var data = new List() { + new IrisDataProperties { SepalLength = 1f, SepalWidth = 1f, PetalLength=0.3f, PetalWidth=5.1f, Label=1}, + new IrisDataProperties { SepalLength = 1f, SepalWidth = 1f, PetalLength=0.3f, PetalWidth=5.1f, Label=1}, + new IrisDataProperties { SepalLength = 1.2f, SepalWidth = 0.5f, PetalLength=0.3f, PetalWidth=5.1f, Label=0} + }; + var collection = CollectionDataSource.Create(data); + + pipeline.Add(collection); + pipeline.Add(new ColumnConcatenator(outputColumn: "Features", + "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")); + pipeline.Add(new StochasticDualCoordinateAscentClassifier()); + PredictionModel model = pipeline.Train(); + + IrisPredictionProperties prediction = model.Predict(new IrisDataProperties() + { + SepalLength = 3.3f, + SepalWidth = 1.6f, + PetalLength = 0.2f, + PetalWidth = 5.1f, + }); + + pipeline = new LearningPipeline(); + collection = CollectionDataSource.Create(data.AsEnumerable()); + pipeline.Add(collection); + pipeline.Add(new ColumnConcatenator(outputColumn: "Features", + "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")); + pipeline.Add(new StochasticDualCoordinateAscentClassifier()); + model = pipeline.Train(); + + prediction = model.Predict(new IrisDataProperties() + { + SepalLength 
= 3.3f, + SepalWidth = 1.6f, + PetalLength = 0.2f, + PetalWidth = 5.1f, + }); + + } + public class Input { [Column("0")] @@ -207,6 +250,37 @@ public class IrisPrediction public float[] PredictedLabels; } + public class IrisDataProperties + { + private float _Label; + private float _SepalLength; + private float _SepalWidth; + private float _PetalLength; + private float _PetalWidth; + + [Column("0")] + public float Label { get { return _Label; } set { _Label = value; } } + + [Column("1")] + public float SepalLength { get { return _SepalLength; } set { _SepalLength = value; } } + + [Column("2")] + public float SepalWidth { get { return _SepalWidth; } set { _SepalWidth = value; } } + + [Column("3")] + public float PetalLength { get { return _PetalLength; } set { _PetalLength = value; } } + + [Column("4")] + public float PetalWidth { get { return _PetalWidth; } set { _PetalWidth = value; } } + } + + public class IrisPredictionProperties + { + private float[] _PredictedLabels; + [ColumnName("Score")] + public float[] PredictedLabels { get { return _PredictedLabels; } set { _PredictedLabels = value; } } + } + public class ConversionSimpleClass { public int fInt; @@ -257,7 +331,7 @@ public bool CompareObjectValues(object x, object y, Type type) public bool CompareThroughReflection(T x, T y) { - foreach (var field in typeof(T).GetFields()) + foreach (var field in typeof(T).GetFields(BindingFlags.Public | BindingFlags.Instance)) { var xvalue = field.GetValue(x); var yvalue = field.GetValue(y); @@ -272,6 +346,25 @@ public bool CompareThroughReflection(T x, T y) return false; } } + foreach (var property in typeof(T).GetProperties(BindingFlags.Public | BindingFlags.Instance)) + { + // Don't compare properties with private getters and setters + if (!property.CanRead || !property.CanWrite || property.GetGetMethod() == null || property.GetSetMethod() == null) + continue; + + var xvalue = property.GetValue(x); + var yvalue = property.GetValue(y); + if 
(property.PropertyType.IsArray) + { + if (!CompareArrayValues(xvalue as Array, yvalue as Array)) + return false; + } + else + { + if (!CompareObjectValues(xvalue, yvalue, property.PropertyType)) + return false; + } + } return true; } @@ -288,14 +381,6 @@ public bool CompareArrayValues(Array x, Array y) return true; } - public class ClassWithConstField - { - public const string ConstString = "N"; - public string fString; - public const int ConstInt = 100; - public int fInt; - } - [Fact] public void RoundTripConversionWithBasicTypes() { @@ -489,6 +574,50 @@ public void ConversionMinValueToNullBehavior() } } + public class ConversionLossMinValueClassProperties + { + private int? _fInt; + private long? _fLong; + private short? _fShort; + private sbyte? _fsByte; + public int? IntProp { get { return _fInt; } set { _fInt = value; } } + public short? ShortProp { get { return _fShort; } set { _fShort = value; } } + public sbyte? SByteProp { get { return _fsByte; } set { _fsByte = value; } } + public long? 
LongProp { get { return _fLong; } set { _fLong = value; } } + } + + [Fact] + public void ConversionMinValueToNullBehaviorProperties() + { + using (var env = new TlcEnvironment()) + { + + var data = new List + { + new ConversionLossMinValueClassProperties() { SByteProp = null, IntProp = null, LongProp = null, ShortProp = null }, + new ConversionLossMinValueClassProperties() { SByteProp = sbyte.MinValue, IntProp = int.MinValue, LongProp = long.MinValue, ShortProp = short.MinValue } + }; + foreach (var field in typeof(ConversionLossMinValueClassProperties).GetFields()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var enumerator = dataView.AsEnumerable(env, false).GetEnumerator(); + while (enumerator.MoveNext()) + { + Assert.True(enumerator.Current.IntProp == null && enumerator.Current.LongProp == null && + enumerator.Current.SByteProp == null && enumerator.Current.ShortProp == null); + } + } + } + } + + public class ClassWithConstField + { + public const string ConstString = "N"; + public string fString; + public const int ConstInt = 100; + public int fInt; + } + [Fact] public void ClassWithConstFieldsConversion() { @@ -510,6 +639,122 @@ public void ClassWithConstFieldsConversion() } } + + public class ClassWithMixOfFieldsAndProperties + { + public string fString; + private int _fInt; + public int IntProp { get { return _fInt; } set { _fInt = value; } } + } + + [Fact] + public void ClassWithMixOfFieldsAndPropertiesConversion() + { + var data = new List() + { + new ClassWithMixOfFieldsAndProperties(){ IntProp=1, fString ="lala" }, + new ClassWithMixOfFieldsAndProperties(){ IntProp=-1, fString ="" }, + new ClassWithMixOfFieldsAndProperties(){ IntProp=0, fString =null } + }; + + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && 
originalEnumerator.MoveNext()) + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); + } + } + + public abstract class BaseClassWithInheritedProperties + { + private string _fString; + private byte _fByte; + public string StringProp { get { return _fString; } set { _fString = value; } } + public abstract long LongProp { get; set; } + public virtual byte ByteProp { get { return _fByte; } set { _fByte = value; } } + } + + + public class ClassWithPrivateFieldsAndProperties + { + public ClassWithPrivateFieldsAndProperties() { seq++; _unusedStaticField++; _unusedPrivateField1 = 100; } + static public int seq; + static public int _unusedStaticField; + private int _unusedPrivateField1; + private string _fString; + + // This property is ignored because it has no setter + private int UnusedReadOnlyProperty { get { return _unusedPrivateField1; } } + + // This property is ignored because it is private + private int UnusedPrivateProperty { get { return _unusedPrivateField1; } set { _unusedPrivateField1 = value; } } + + // This property is ignored because it has a private setter + public int UnusedPropertyWithPrivateSetter { get { return _unusedPrivateField1; } private set { _unusedPrivateField1 = value; } } + + // This property is ignored because it has a private getter + public int UnusedPropertyWithPrivateGetter { private get { return _unusedPrivateField1; } set { _unusedPrivateField1 = value; } } + + public string StringProp { get { return _fString; } set { _fString = value; } } + } + + [Fact] + public void ClassWithPrivateFieldsAndPropertiesConversion() + { + var data = new List() + { + new ClassWithPrivateFieldsAndProperties(){ StringProp ="lala" }, + new ClassWithPrivateFieldsAndProperties(){ StringProp ="baba" } + }; + + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = 
dataView.AsEnumerable(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) + { + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); + Assert.True(enumeratorSimple.Current.UnusedPropertyWithPrivateSetter == 100); + } + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); + } + } + + public class ClassWithInheritedProperties : BaseClassWithInheritedProperties + { + private int _fInt; + private long _fLong; + private byte _fByte2; + public int IntProp { get { return _fInt; } set { _fInt = value; } } + public override long LongProp { get { return _fLong; } set { _fLong = value; } } + public override byte ByteProp { get { return _fByte2; } set { _fByte2 = value; } } + } + + [Fact] + public void ClassWithInheritedPropertiesConversion() + { + var data = new List() + { + new ClassWithInheritedProperties(){ IntProp=1, StringProp ="lala", LongProp=17, ByteProp=3 }, + new ClassWithInheritedProperties(){ IntProp=-1, StringProp ="", LongProp=2, ByteProp=4 }, + new ClassWithInheritedProperties(){ IntProp=0, StringProp =null, LongProp=18, ByteProp=5 } + }; + + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); + } + } + public class ClassWithArrays { public string[] fString; @@ -609,5 +854,129 @@ public void RoundTripConversionWithArrays() Assert.True(!enumeratorNullable.MoveNext() && !originalNullalbleEnumerator.MoveNext()); } } + public class ClassWithArrayProperties + { + private string[] _fString; + 
private int[] _fInt; + private uint[] _fuInt; + private short[] _fShort; + private ushort[] _fuShort; + private sbyte[] _fsByte; + private byte[] _fByte; + private long[] _fLong; + private ulong[] _fuLong; + private float[] _fFloat; + private double[] _fDouble; + private bool[] _fBool; + public string[] StringProp { get { return _fString; } set { _fString = value; } } + public int[] IntProp { get { return _fInt; } set { _fInt = value; } } + public uint[] UIntProp { get { return _fuInt; } set { _fuInt = value; } } + public short[] ShortProp { get { return _fShort; } set { _fShort = value; } } + public ushort[] UShortProp { get { return _fuShort; } set { _fuShort = value; } } + public sbyte[] SByteProp { get { return _fsByte; } set { _fsByte = value; } } + public byte[] ByteProp { get { return _fByte; } set { _fByte = value; } } + public long[] LongProp { get { return _fLong; } set { _fLong = value; } } + public ulong[] ULongProp { get { return _fuLong; } set { _fuLong = value; } } + public float[] FloatProp { get { return _fFloat; } set { _fFloat = value; } } + public double[] DobuleProp { get { return _fDouble; } set { _fDouble = value; } } + public bool[] BoolProp { get { return _fBool; } set { _fBool = value; } } + } + + public class ClassWithNullableArrayProperties + { + private string[] _fString; + private int?[] _fInt; + private uint?[] _fuInt; + private short?[] _fShort; + private ushort?[] _fuShort; + private sbyte?[] _fsByte; + private byte?[] _fByte; + private long?[] _fLong; + private ulong?[] _fuLong; + private float?[] _fFloat; + private double?[] _fDouble; + private bool?[] _fBool; + + public string[] StringProp { get { return _fString; } set { _fString = value; } } + public int?[] IntProp { get { return _fInt; } set { _fInt = value; } } + public uint?[] UIntProp { get { return _fuInt; } set { _fuInt = value; } } + public short?[] ShortProp { get { return _fShort; } set { _fShort = value; } } + public ushort?[] UShortProp { get { return _fuShort; } set 
{ _fuShort = value; } } + public sbyte?[] SByteProp { get { return _fsByte; } set { _fsByte = value; } } + public byte?[] ByteProp { get { return _fByte; } set { _fByte = value; } } + public long?[] LongProp { get { return _fLong; } set { _fLong = value; } } + public ulong?[] ULongProp { get { return _fuLong; } set { _fuLong = value; } } + public float?[] SingleProp { get { return _fFloat; } set { _fFloat = value; } } + public double?[] DoubleProp { get { return _fDouble; } set { _fDouble = value; } } + public bool?[] BoolProp { get { return _fBool; } set { _fBool = value; } } + } + + [Fact] + public void RoundTripConversionWithArrayPropertiess() + { + + var data = new List + { + new ClassWithArrayProperties() + { + IntProp = new int[3] { 0, 1, 2 }, + FloatProp = new float[3] { -0.99f, 0f, 0.99f }, + StringProp = new string[2] { "hola", "lola" }, + BoolProp = new bool[2] { true, false }, + ByteProp = new byte[3] { 0, 124, 255 }, + DobuleProp = new double[3] { -1, 0, 1 }, + LongProp = new long[] { 0, 1, 2 }, + SByteProp = new sbyte[3] { -127, 127, 0 }, + ShortProp = new short[3] { 0, 1225, 32767 }, + UIntProp = new uint[2] { 0, uint.MaxValue }, + ULongProp = new ulong[2] { ulong.MaxValue, 0 }, + UShortProp = new ushort[2] { 0, ushort.MaxValue } + }, + new ClassWithArrayProperties() { IntProp = new int[3] { -2, 1, 0 }, FloatProp = new float[3] { 0.99f, 0f, -0.99f }, StringProp = new string[2] { "", null } }, + new ClassWithArrayProperties() + }; + + var nullableData = new List + { + new ClassWithNullableArrayProperties() + { + IntProp = new int?[3] { null, -1, 1 }, + SingleProp = new float?[3] { -0.99f, null, 0.99f }, + StringProp = new string[2] { null, "" }, + BoolProp = new bool?[3] { true, null, false }, + ByteProp = new byte?[4] { 0, 125, null, 255 }, + DoubleProp = new double?[3] { -1, null, 1 }, + LongProp = new long?[] { null, -1, 1 }, + SByteProp = new sbyte?[3] { -127, 127, null }, + ShortProp = new short?[3] { 0, null, 32767 }, + UIntProp = new uint?[4] { 
null, 42, 0, uint.MaxValue }, + ULongProp = new ulong?[3] { ulong.MaxValue, null, 0 }, + UShortProp = new ushort?[3] { 0, null, ushort.MaxValue } + }, + new ClassWithNullableArrayProperties() { IntProp = new int?[3] { -2, 1, 0 }, SingleProp = new float?[3] { 0.99f, 0f, -0.99f }, StringProp = new string[2] { "lola", "hola" } }, + new ClassWithNullableArrayProperties() + }; + + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) + { + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); + } + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); + + var nullableDataView = ComponentCreation.CreateDataView(env, nullableData); + var enumeratorNullable = nullableDataView.AsEnumerable(env, false).GetEnumerator(); + var originalNullalbleEnumerator = nullableData.GetEnumerator(); + while (enumeratorNullable.MoveNext() && originalNullalbleEnumerator.MoveNext()) + { + Assert.True(CompareThroughReflection(enumeratorNullable.Current, originalNullalbleEnumerator.Current)); + } + Assert.True(!enumeratorNullable.MoveNext() && !originalNullalbleEnumerator.MoveNext()); + } + } } } diff --git a/test/Microsoft.ML.Tests/TextLoaderTests.cs b/test/Microsoft.ML.Tests/TextLoaderTests.cs index 61a1744dfb..50a7e55975 100644 --- a/test/Microsoft.ML.Tests/TextLoaderTests.cs +++ b/test/Microsoft.ML.Tests/TextLoaderTests.cs @@ -228,7 +228,7 @@ public void CanSuccessfullyTrimSpaces() public void ThrowsExceptionWithPropertyName() { Exception ex = Assert.Throws( () => new Data.TextLoader("fakefile.txt").CreateFrom() ); - Assert.StartsWith("String1 is missing ColumnAttribute", ex.Message); + Assert.StartsWith("Field or property String1 is missing ColumnAttribute", ex.Message); } public class 
QuoteInput diff --git a/test/run-tests.proj b/test/run-tests.proj index dd2433b3c5..a5afe75dd3 100644 --- a/test/run-tests.proj +++ b/test/run-tests.proj @@ -3,6 +3,7 @@ +