From fb536fd712fc2299f1e8c155e9f5a29ca67dddbf Mon Sep 17 00:00:00 2001 From: Don Syme Date: Mon, 30 Jul 2018 16:12:04 +0100 Subject: [PATCH 01/15] add F# smoke test --- Microsoft.ML.sln | 11 ++ .../Microsoft.ML.FSharp.Tests.fsproj | 53 +++++++ test/Microsoft.ML.FSharp.Tests/SmokeTests.fs | 132 ++++++++++++++++++ 3 files changed, 196 insertions(+) create mode 100644 test/Microsoft.ML.FSharp.Tests/Microsoft.ML.FSharp.Tests.fsproj create mode 100644 test/Microsoft.ML.FSharp.Tests/SmokeTests.fs diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index 58e24041f1..cdf43c7794 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -97,6 +97,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer", EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer.Tests", "test\Microsoft.ML.CodeAnalyzer.Tests\Microsoft.ML.CodeAnalyzer.Tests.csproj", "{3E4ABF07-7970-4BE6-B45B-A13D3C397545}" EndProject +Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Microsoft.ML.FSharp.Tests", "test\Microsoft.ML.FSharp.Tests\Microsoft.ML.FSharp.Tests.fsproj", "{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -329,6 +331,14 @@ Global {3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release|Any CPU.Build.0 = Release|Any CPU {3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU {3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU + {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug|Any CPU.Build.0 = Debug|Any CPU + {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU + {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU + {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release|Any CPU.ActiveCfg = Release|Any CPU + {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release|Any CPU.Build.0 = Release|Any CPU + {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU + {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -367,6 +377,7 @@ Global {BF66A305-DF10-47E4-8D81-42049B149D2B} = {D3D38B03-B557-484D-8348-8BADEE4DF592} {B4E55B2D-2A92-46E7-B72F-E76D6FD83440} = {7F13E156-3EBA-4021-84A5-CD56BA72F99E} {3E4ABF07-7970-4BE6-B45B-A13D3C397545} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} + {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D} diff --git a/test/Microsoft.ML.FSharp.Tests/Microsoft.ML.FSharp.Tests.fsproj b/test/Microsoft.ML.FSharp.Tests/Microsoft.ML.FSharp.Tests.fsproj new file mode 100644 index 0000000000..0dd1dc2a40 --- /dev/null +++ b/test/Microsoft.ML.FSharp.Tests/Microsoft.ML.FSharp.Tests.fsproj @@ -0,0 +1,53 @@ + + + + netcoreapp2.0 + 2003;$(NoWarn) + $(TargetFrameworks); net461 + false + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs new file mode 100644 index 0000000000..c6c05d9714 --- /dev/null +++ b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs @@ -0,0 +1,132 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + + +//================================================================================================= +// This test can be run either as a compiled test with .NET Core (on any platform) or +// manually in script form (to help debug it and also check that F# scripting works with ML.NET). +// Running as a script requires using F# Interactive on Windows, and the explicit references below. +// The references would normally be created by a package loader for the scripting +// environment, e.g. see https://github.com/isaacabraham/ml-test-experiment/, but +// here we list them explicitly to avoid the dependency on a package loader, +// +// You should build Microsoft.ML.FSharp.Tests in Debug mode for framework net461 +// before running this as a script with F# Interactive by editing the project +// file to have: +// netcoreapp2.0; net461 + +#if INTERACTIVE +#r "netstandard" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Core.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Google.Protobuf.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Newtonsoft.Json.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/System.CodeDom.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/System.Threading.Tasks.Dataflow.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.CpuMath.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Data.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Transforms.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.ResultProcessor.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.PCA.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.KMeansClustering.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.FastTree.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Api.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.Sweeper.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.StandardLearners.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/Microsoft.ML.PipelineInference.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/xunit.core.dll" +#r @"../../bin/AnyCPU.Debug/Microsoft.ML.FSharp.Tests/net461/xunit.assert.dll" +#r "System" +#r "System.ComponentModel.Composition" +#r "System.Core" +#r "System.Xml.Linq" + +// Later tests will add data import using F# type providers: +//#r @"../../packages/fsharp.data/3.0.0-beta4/lib/netstandard2.0/FSharp.Data.dll" // this must be referenced from its package location + +let _load = + // See https://github.com/dotnet/machinelearning/issues/401: forces the loading of ML.NET assemblies + [ typeof; + typeof ] + +#endif + +//================================================================================ +// The tests proper start here + +#if !INTERACTIVE +namespace Microsoft.ML.FSharp.Tests +#endif + +open System +open Microsoft.ML +open Microsoft.ML.Data +open Microsoft.ML.Transforms +open Microsoft.ML.Trainers +open Microsoft.ML.Runtime.Api +open Xunit + +module SmokeTest1 = + + type SentimentData() = + [] + val mutable SentimentText : string + [] + val mutable Sentiment : float32 + + type SentimentPrediction() = + [] + val mutable Sentiment : bool + + [] + let ``FSharp-Sentiment-Smoke-Test`` () = + + let testDataPath = __SOURCE_DIRECTORY__ + @"/../data/wikipedia-detox-250-line-data.tsv" + + let pipeline = LearningPipeline() + + pipeline.Add( + TextLoader(testDataPath).CreateFrom( + Arguments = + TextLoaderArguments( + HasHeader = true, + Column = [| TextLoaderColumn(Name = "Label", Source = [| TextLoaderRange(0) |], Type = Nullable (Data.DataKind.Num)) + TextLoaderColumn(Name = "SentimentText", Source = [| TextLoaderRange(1) |], Type = Nullable (Data.DataKind.Text)) |] + ))) + + pipeline.Add( + TextFeaturizer( + "Features", [| "SentimentText" |], + KeepDiacritics = false, + KeepPunctuations = false, + TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, + OutputTokens = true, + VectorNormalizer = TextTransformTextNormKind.L2 + )) + + pipeline.Add( + FastTreeBinaryClassifier( + NumLeaves = 5, + NumTrees = 5, + MinDocumentsInLeafs = 2 + )) + + let model = pipeline.Train() + + let predictions = + [ SentimentData(SentimentText = "This is a gross exaggeration. Nobody is setting a kangaroo court. There was a simple addition.") + SentimentData(SentimentText = "Sort of ok") + SentimentData(SentimentText = "Joe versus the Volcano Coffee Company is a great film.") ] + |> model.Predict + + let predictionResults = [ for p in predictions -> p.Sentiment ] + Assert.Equal(predictionResults, [ false; true; true ]) + +#if NETCOREAPP2_0 +module Program = + + [] + let main _ = 0 +#endif + From e9040e9f3bc2d27e8d2f734b2e6f02bdbbc2cd51 Mon Sep 17 00:00:00 2001 From: Don Syme Date: Mon, 30 Jul 2018 19:45:20 +0100 Subject: [PATCH 02/15] records --- test/Microsoft.ML.FSharp.Tests/SmokeTests.fs | 67 +++++++++++++++++++- 1 file changed, 65 insertions(+), 2 deletions(-) diff --git a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs index c6c05d9714..1306314427 100644 --- a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs +++ b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs @@ -91,8 +91,12 @@ module SmokeTest1 = Arguments = TextLoaderArguments( HasHeader = true, - Column = [| TextLoaderColumn(Name = "Label", Source = [| TextLoaderRange(0) |], Type = Nullable (Data.DataKind.Num)) - TextLoaderColumn(Name = "SentimentText", Source = [| TextLoaderRange(1) |], Type = Nullable (Data.DataKind.Text)) |] + Column = [| TextLoaderColumn(Name = "Label", + Source = [| TextLoaderRange(0) |], + Type = Nullable (Data.DataKind.Num)) + TextLoaderColumn(Name = "SentimentText", + Source = [| TextLoaderRange(1) |], + Type = Nullable (Data.DataKind.Text)) |] ))) pipeline.Add( @@ -123,6 +127,65 @@ module SmokeTest1 = let predictionResults = [ for p in predictions -> p.Sentiment ] Assert.Equal(predictionResults, [ false; true; true ]) +module SmokeTest2 = + + type SentimentData = + { [] SentimentText : string + [] Sentiment : float } + + [] + type SentimentPrediction = + { [] Sentiment : bool } + + [] + let ``FSharp-Sentiment-Smoke-Test`` () = + + let testDataPath = __SOURCE_DIRECTORY__ + @"/../data/wikipedia-detox-250-line-data.tsv" + + let pipeline = LearningPipeline() + + pipeline.Add( + TextLoader(testDataPath).CreateFrom( + Arguments = + TextLoaderArguments( + HasHeader = true, + Column = [| TextLoaderColumn(Name = "Label", + Source = [| TextLoaderRange(0) |], + Type = Nullable (Data.DataKind.Num)) + TextLoaderColumn(Name = "SentimentText", + Source = [| TextLoaderRange(1) |], + Type = Nullable (Data.DataKind.Text)) |] + ))) + + pipeline.Add( + TextFeaturizer( + "Features", [| "SentimentText" |], + KeepDiacritics = false, + KeepPunctuations = false, + TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, + OutputTokens = true, + VectorNormalizer = TextTransformTextNormKind.L2 + )) + + pipeline.Add( + FastTreeBinaryClassifier( + NumLeaves = 5, + NumTrees = 5, + MinDocumentsInLeafs = 2 + )) + + let model = pipeline.Train() + + let predictions = + [ SentimentData(SentimentText = "This is a gross exaggeration. Nobody is setting a kangaroo court. There was a simple addition.") + SentimentData(SentimentText = "Sort of ok") + SentimentData(SentimentText = "Joe versus the Volcano Coffee Company is a great film.") ] + |> model.Predict + + let predictionResults = [ for p in predictions -> p.Sentiment ] + Assert.Equal(predictionResults, [ false; true; true ]) + + #if NETCOREAPP2_0 module Program = From 8ec2fd52129bc81b6e3b50b861442dd22be040b8 Mon Sep 17 00:00:00 2001 From: Don Syme Date: Mon, 30 Jul 2018 19:57:21 +0100 Subject: [PATCH 03/15] update for code review --- .../Microsoft.ML.FSharp.Tests.fsproj | 7 ++----- test/run-tests.proj | 1 + 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/test/Microsoft.ML.FSharp.Tests/Microsoft.ML.FSharp.Tests.fsproj b/test/Microsoft.ML.FSharp.Tests/Microsoft.ML.FSharp.Tests.fsproj index 0dd1dc2a40..196ec1b893 100644 --- a/test/Microsoft.ML.FSharp.Tests/Microsoft.ML.FSharp.Tests.fsproj +++ b/test/Microsoft.ML.FSharp.Tests/Microsoft.ML.FSharp.Tests.fsproj @@ -2,9 +2,10 @@ netcoreapp2.0 - 2003;$(NoWarn) $(TargetFrameworks); net461 + 2003;$(NoWarn) false + @@ -12,10 +13,6 @@ - - - - diff --git a/test/run-tests.proj b/test/run-tests.proj index dd2433b3c5..a5afe75dd3 100644 --- a/test/run-tests.proj +++ b/test/run-tests.proj @@ -3,6 +3,7 @@ + From 968766f3b8f9821c5657f4cd9968f93e5e745f8e Mon Sep 17 00:00:00 2001 From: Don Syme Date: Tue, 31 Jul 2018 13:44:31 +0100 Subject: [PATCH 04/15] attempt properties --- src/Microsoft.ML.Api/ApiUtils.cs | 96 +++++++++++++++---- .../DataViewConstructionUtils.cs | 2 +- .../InternalSchemaDefinition.cs | 67 +++++++------ src/Microsoft.ML.Api/SchemaDefinition.cs | 95 +++++++++++++----- src/Microsoft.ML.Api/TypedCursor.cs | 19 ++-- src/Microsoft.ML/Data/TextLoader.cs | 44 +++++++-- test/Microsoft.ML.FSharp.Tests/SmokeTests.fs | 15 +-- 7 files changed, 243 insertions(+), 95 deletions(-) diff --git a/src/Microsoft.ML.Api/ApiUtils.cs b/src/Microsoft.ML.Api/ApiUtils.cs index 8b8cb5871b..54846eb631 100644 --- a/src/Microsoft.ML.Api/ApiUtils.cs +++ b/src/Microsoft.ML.Api/ApiUtils.cs @@ -51,14 +51,30 @@ private static OpCode GetAssignmentOpCode(Type t) /// internal static Delegate GeneratePeek(InternalSchemaDefinition.Column column) { - var fieldInfo = column.FieldInfo; - Type fieldType = fieldInfo.FieldType; - - var assignmentOpCode = GetAssignmentOpCode(fieldType); - Func func = GeneratePeek; - var methInfo = func.GetMethodInfo().GetGenericMethodDefinition() - .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType); - return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode }); + switch (column.MemberInfo) + { + case FieldInfo fieldInfo: + Type fieldType = fieldInfo.FieldType; + + var assignmentOpCode = GetAssignmentOpCode(fieldType); + Func func = GeneratePeek; + var methInfo = func.GetMethodInfo().GetGenericMethodDefinition() + .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType); + return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode }); + + case PropertyInfo propertyInfo: + Type propertyType = propertyInfo.PropertyType; + + var assignmentOpCodeProp = GetAssignmentOpCode(propertyType); + Func funcProp = GeneratePeek; + var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition() + .MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType); + return (Delegate)methInfoProp.Invoke(null, new object[] { propertyInfo, assignmentOpCodeProp }); + + default: + throw Contracts.ExceptNotSupp("expected a FieldInfo or a PropInfo"); + + } } private static Delegate GeneratePeek(FieldInfo fieldInfo, OpCode assignmentOpCode) @@ -81,6 +97,26 @@ private static Delegate GeneratePeek(FieldInfo fieldInfo, Op return mb.CreateDelegate(typeof(Peek)); } + private static Delegate GeneratePeek(PropertyInfo propertyInfo, OpCode assignmentOpCode) + { + // REVIEW: It seems like we really should cache these, instead of generating them per cursor. + Type[] args = { typeof(TOwn), typeof(TRow), typeof(long), typeof(TValue).MakeByRefType() }; + var mb = new DynamicMethod("Peek", null, args, typeof(TOwn), true); + var il = mb.GetILGenerator(); + + il.Emit(OpCodes.Ldarg_3); // push arg3 + il.Emit(OpCodes.Ldarg_1); // push arg1 + il.Emit(OpCodes.Call, propertyInfo.GetGetMethod()); // push [stack top].[propertyInfo] + // Stobj needs to coupled with a type. + if (assignmentOpCode == OpCodes.Stobj) // [stack top-1] = [stack top] + il.Emit(assignmentOpCode, propertyInfo.PropertyType); + else + il.Emit(assignmentOpCode); + il.Emit(OpCodes.Ret); // ret + + return mb.CreateDelegate(typeof(Peek)); + } + /// /// Each of the specialized 'poke' methods sets the appropriate field value of an instance of T /// to the provided value. So, the call is 'peek(userObject, providedValue)' and the logic is @@ -88,14 +124,29 @@ private static Delegate GeneratePeek(FieldInfo fieldInfo, Op /// internal static Delegate GeneratePoke(InternalSchemaDefinition.Column column) { - var fieldInfo = column.FieldInfo; - Type fieldType = fieldInfo.FieldType; - - var assignmentOpCode = GetAssignmentOpCode(fieldType); - Func func = GeneratePoke; - var methInfo = func.GetMethodInfo().GetGenericMethodDefinition() - .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType); - return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode }); + switch (column.MemberInfo) + { + case FieldInfo fieldInfo: + Type fieldType = fieldInfo.FieldType; + + var assignmentOpCode = GetAssignmentOpCode(fieldType); + Func func = GeneratePoke; + var methInfo = func.GetMethodInfo().GetGenericMethodDefinition() + .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType); + return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode }); + + case PropertyInfo propertyInfo: + Type propertyType = propertyInfo.PropertyType; + + var assignmentOpCodeProp = GetAssignmentOpCode(propertyType); + Func funcProp = GeneratePoke; + var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition() + .MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType); + return (Delegate)methInfoProp.Invoke(null, new object[] { propertyInfo }); + + default: + throw Contracts.ExceptNotSupp("expected a FieldInfo or a PropInfo"); + } } private static Delegate GeneratePoke(FieldInfo fieldInfo, OpCode assignmentOpCode) @@ -115,5 +166,18 @@ private static Delegate GeneratePoke(FieldInfo fieldInfo, Op il.Emit(OpCodes.Ret); // ret return mb.CreateDelegate(typeof(Poke), null); } + + private static Delegate GeneratePoke(PropertyInfo propertyInfo) + { + Type[] args = { typeof(TOwn), typeof(TRow), typeof(TValue) }; + var mb = new DynamicMethod("Poke", null, args, typeof(TOwn), true); + var il = mb.GetILGenerator(); + + il.Emit(OpCodes.Ldarg_1); // push arg1 + il.Emit(OpCodes.Ldarg_2); // push arg2 + il.Emit(OpCodes.Call, propertyInfo.GetSetMethod()); // [stack top-1].[propertyInfo] <- [stack top] + il.Emit(OpCodes.Ret); // ret + return mb.CreateDelegate(typeof(Poke), null); + } } } diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index 341e3a72af..b59fe64bb2 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -118,7 +118,7 @@ private Delegate CreateGetter(int index) var colType = DataView.Schema.GetColumnType(index); var column = DataView._schema.SchemaDefn.Columns[index]; - var outputType = column.IsComputed ? column.ReturnType : column.FieldInfo.FieldType; + var outputType = column.OutputType; var genericType = outputType; Func del; diff --git a/src/Microsoft.ML.Api/InternalSchemaDefinition.cs b/src/Microsoft.ML.Api/InternalSchemaDefinition.cs index 3edf7599a4..dcc804c1c4 100644 --- a/src/Microsoft.ML.Api/InternalSchemaDefinition.cs +++ b/src/Microsoft.ML.Api/InternalSchemaDefinition.cs @@ -23,21 +23,22 @@ internal sealed class InternalSchemaDefinition public class Column { public readonly string ColumnName; - public readonly FieldInfo FieldInfo; + public readonly MemberInfo MemberInfo; public readonly ParameterInfo ReturnParameterInfo; public readonly ColumnType ColumnType; public readonly bool IsComputed; public readonly Delegate Generator; private readonly Dictionary _metadata; public Dictionary Metadata { get { return _metadata; } } - public Type ReturnType {get { return ReturnParameterInfo.ParameterType.GetElementType(); }} + public Type ComputedReturnType {get { return ReturnParameterInfo.ParameterType.GetElementType(); }} + public Type OutputType => IsComputed ? ComputedReturnType : (MemberInfo is FieldInfo) ? (MemberInfo as FieldInfo).FieldType : (MemberInfo as PropertyInfo).PropertyType; - public Column(string columnName, ColumnType columnType, FieldInfo fieldInfo) : - this(columnName, columnType, fieldInfo, null, null) { } + public Column(string columnName, ColumnType columnType, MemberInfo memberInfo) : + this(columnName, columnType, memberInfo, null, null) { } - public Column(string columnName, ColumnType columnType, FieldInfo fieldInfo, + public Column(string columnName, ColumnType columnType, MemberInfo memberInfo, Dictionary metadataInfos) : - this(columnName, columnType, fieldInfo, null, metadataInfos) { } + this(columnName, columnType, memberInfo, null, metadataInfos) { } public Column(string columnName, ColumnType columnType, Delegate generator) : this(columnName, columnType, null, generator, null) { } @@ -46,7 +47,7 @@ public Column(string columnName, ColumnType columnType, Delegate generator, Dictionary metadataInfos) : this(columnName, columnType, null, generator, metadataInfos) { } - private Column(string columnName, ColumnType columnType, FieldInfo fieldInfo = null, + private Column(string columnName, ColumnType columnType, MemberInfo memberInfo = null, Delegate generator = null, Dictionary metadataInfos = null) { Contracts.AssertNonEmpty(columnName); @@ -55,8 +56,8 @@ private Column(string columnName, ColumnType columnType, FieldInfo fieldInfo = n if (generator == null) { - Contracts.AssertValue(fieldInfo); - FieldInfo = fieldInfo; + Contracts.AssertValue(memberInfo); + MemberInfo = memberInfo; } else { @@ -95,8 +96,8 @@ public void AssertRep() // If Column is computed type, it must have a generator. Contracts.Assert(IsComputed == (Generator != null)); - // Column must have either a generator or a fieldInfo value. - Contracts.Assert((Generator == null) != (FieldInfo == null)); + // Column must have either a generator or a memberInfo value. + Contracts.Assert((Generator == null) != (MemberInfo == null)); // Additional Checks if there is a generator. if (Generator == null) @@ -115,9 +116,7 @@ public void AssertRep() Contracts.Assert(Generator.GetMethodInfo().ReturnType == typeof(void)); // Checks that the return type of the generator is compatible with ColumnType. - bool isVector; - DataKind datakind; - GetVectorAndKind(ReturnType, "return type", out isVector, out datakind); + GetVectorAndKind(ComputedReturnType, "return type", out bool isVector, out DataKind datakind); Contracts.Assert(isVector == ColumnType.IsVector); Contracts.Assert(datakind == ColumnType.ItemType.RawKind); } @@ -131,19 +130,29 @@ private InternalSchemaDefinition(Column[] columns) } /// - /// Given a field info on a type, returns whether this appears to be a vector type, + /// Given a field or property info on a type, returns whether this appears to be a vector type, /// and also the associated data kind for this type. If a data kind could not /// be determined, this will throw. /// - /// The field info to inspect. + /// The field or property info to inspect. /// Whether this appears to be a vector type. /// The data kind of the type, or items of this type if vector. - public static void GetVectorAndKind(FieldInfo fieldInfo, out bool isVector, out DataKind kind) + public static void GetVectorAndKind(MemberInfo memberInfo, out bool isVector, out DataKind kind) { - Contracts.AssertValue(fieldInfo); - Type rawFieldType = fieldInfo.FieldType; - var name = fieldInfo.Name; - GetVectorAndKind(rawFieldType, name, out isVector, out kind); + Contracts.AssertValue(memberInfo); + switch (memberInfo) + { + case FieldInfo fieldInfo: + GetVectorAndKind(fieldInfo.FieldType, fieldInfo.Name, out isVector, out kind); + break; + + case PropertyInfo propertyInfo: + GetVectorAndKind(propertyInfo.PropertyType, propertyInfo.Name, out isVector, out kind); + break; + + default: + throw Contracts.ExceptNotSupp("expected a FieldInfo or a PropInfo"); + } } /// @@ -211,23 +220,27 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us bool isVector; DataKind kind; - FieldInfo fieldInfo = null; + MemberInfo memberInfo = null; if (!col.IsComputed) { - fieldInfo = userType.GetField(col.MemberName); + memberInfo = userType.GetField(col.MemberName); + + if (memberInfo == null) + memberInfo = userType.GetProperty(col.MemberName); - if (fieldInfo == null) + if (memberInfo == null) throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No field with name '{0}' found in type '{1}'", col.MemberName, userType.FullName); //Clause to handle the field that may be used to expose the cursor channel. //This field does not need a column. - if (fieldInfo.FieldType == typeof(IChannel)) + if ( (memberInfo is FieldInfo && (memberInfo as FieldInfo).FieldType == typeof(IChannel)) || + (memberInfo is PropertyInfo && (memberInfo as PropertyInfo).PropertyType == typeof(IChannel))) continue; - GetVectorAndKind(fieldInfo, out isVector, out kind); + GetVectorAndKind(memberInfo, out isVector, out kind); } else { @@ -268,7 +281,7 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us dstCols[i] = col.IsComputed ? new Column(colName, colType, col.Generator, col.Metadata) - : new Column(colName, colType, fieldInfo, col.Metadata); + : new Column(colName, colType, memberInfo, col.Metadata); } return new InternalSchemaDefinition(dstCols); diff --git a/src/Microsoft.ML.Api/SchemaDefinition.cs b/src/Microsoft.ML.Api/SchemaDefinition.cs index e08845a87e..5a67b02b4c 100644 --- a/src/Microsoft.ML.Api/SchemaDefinition.cs +++ b/src/Microsoft.ML.Api/SchemaDefinition.cs @@ -158,19 +158,40 @@ public static bool TrySetCursorChannel(IExceptionContext ectx, T obj, IChanne .Where(x => x.GetCustomAttributes(typeof(CursorChannelAttribute), false).Any()) .ToArray(); + var cursorChannelAttrProperties = typeof(T) + .GetProperties(BindingFlags.Public | BindingFlags.Instance) + .Where(x => x.CanRead && x.CanWrite && x.GetIndexParameters().Length == 0) + .Where(x => x.GetCustomAttributes(typeof(CursorChannelAttribute), false).Any()) + .ToArray(); + + var cursorChannelAttrMembers = (cursorChannelAttrFields as IEnumerable).Concat(cursorChannelAttrProperties).ToArray(); + //Check that there is at most one such field. - if (cursorChannelAttrFields.Length == 0) + if (cursorChannelAttrMembers.Length == 0) return false; - ectx.Check(cursorChannelAttrFields.Length == 1, - "Only one field with CursorChannel attribute is allowed."); + ectx.Check(cursorChannelAttrMembers.Length == 1, + "Only one public field or property with CursorChannel attribute is allowed."); //Check that the marked field has type IChannel. - var cursorChannelFieldInfo = cursorChannelAttrFields[0]; - ectx.Check(cursorChannelFieldInfo.FieldType == typeof(IChannel), - "Field marked as CursorChannel must have type IChannel."); - - cursorChannelFieldInfo.SetValue(obj, channel); + var cursorChannelAttrMemberInfo = cursorChannelAttrMembers[0]; + switch (cursorChannelAttrMemberInfo) + { + case FieldInfo cursorChannelAttrFieldInfo: + ectx.Check(cursorChannelAttrFieldInfo.FieldType == typeof(IChannel), + "Field marked as CursorChannel must have type IChannel."); + cursorChannelAttrFieldInfo.SetValue(obj, channel); + break; + + case PropertyInfo cursorChannelAttrPropertyInfo: + ectx.Check(cursorChannelAttrPropertyInfo.PropertyType == typeof(IChannel), + "Property marked as CursorChannel must have type IChannel."); + cursorChannelAttrPropertyInfo.SetValue(obj, channel); + break; + + default: + throw Contracts.ExceptNotSupp("expected a FieldInfo or a PropInfo"); + } return true; } } @@ -319,37 +340,63 @@ public static SchemaDefinition Create(Type userType) SchemaDefinition cols = new SchemaDefinition(); HashSet colNames = new HashSet(); - foreach (var fieldInfo in userType.GetFields()) + + var fieldInfos = userType.GetFields(); + var propertyInfos = + userType + .GetProperties(BindingFlags.Public | BindingFlags.Instance) + .Where(x => x.CanRead && x.CanWrite && x.GetIndexParameters().Length == 0) + .ToArray(); + + var memberInfos = (fieldInfos as IEnumerable).Concat(propertyInfos).ToArray(); + + foreach (var memberInfo in memberInfos) { // Clause to handle the field that may be used to expose the cursor channel. // This field does not need a column. // REVIEW: maybe validate the channel attribute now, instead // of later at cursor creation. - if (fieldInfo.FieldType == typeof(IChannel)) - continue; - // Const fields do not need to be mapped. - if (fieldInfo.IsLiteral) - continue; + switch (memberInfo) + { + case FieldInfo fieldInfo: + if (fieldInfo.FieldType == typeof(IChannel)) + continue; + + // Const fields do not need to be mapped. + if (fieldInfo.IsLiteral) + continue; - if (fieldInfo.GetCustomAttribute() != null) + break; + + case PropertyInfo propertyInfo: + if (propertyInfo.PropertyType == typeof(IChannel)) + continue; + break; + + default: + throw Contracts.ExceptNotSupp("expected a FieldInfo or a PropInfo"); + } + + if (memberInfo.GetCustomAttribute() != null) continue; - var mappingAttr = fieldInfo.GetCustomAttribute(); - var mappingNameAttr = fieldInfo.GetCustomAttribute(); - string name = mappingAttr?.Name ?? mappingNameAttr?.Name ?? fieldInfo.Name; + + var mappingAttr = memberInfo.GetCustomAttribute(); + var mappingNameAttr = memberInfo.GetCustomAttribute(); + string name = mappingAttr?.Name ?? mappingNameAttr?.Name ?? memberInfo.Name; // Disallow duplicate names, because the field enumeration order is not actually // well defined, so we are not gauranteed to have consistent "hiding" from run to // run, across different .NET versions. if (!colNames.Add(name)) throw Contracts.ExceptParam(nameof(userType), "Duplicate column name '{0}' detected, this is disallowed", name); - InternalSchemaDefinition.GetVectorAndKind(fieldInfo, out bool isVector, out DataKind kind); + InternalSchemaDefinition.GetVectorAndKind(memberInfo, out bool isVector, out DataKind kind); PrimitiveType itemType; - var keyAttr = fieldInfo.GetCustomAttribute(); + var keyAttr = memberInfo.GetCustomAttribute(); if (keyAttr != null) { if (!KeyType.IsValidDataKind(kind)) - throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", fieldInfo.Name); + throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", memberInfo.Name); itemType = new KeyType(kind, keyAttr.Min, keyAttr.Count, keyAttr.Contiguous); } else @@ -357,9 +404,9 @@ public static SchemaDefinition Create(Type userType) // Get the column type. ColumnType columnType; - var vectorAttr = fieldInfo.GetCustomAttribute(); + var vectorAttr = memberInfo.GetCustomAttribute(); if (vectorAttr != null && !isVector) - throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with VectorType attribute, but does not appear to be a vector type", fieldInfo.Name); + throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with VectorType attribute, but does not appear to be a vector type", memberInfo.Name); if (isVector) { int[] dims = vectorAttr?.Dims; @@ -373,7 +420,7 @@ public static SchemaDefinition Create(Type userType) else columnType = itemType; - cols.Add(new Column() { MemberName = fieldInfo.Name, ColumnName = name, ColumnType = columnType }); + cols.Add(new Column() { MemberName = memberInfo.Name, ColumnName = name, ColumnType = columnType }); } return cols; } diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index f6ebaf687f..ee7fb23143 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -103,11 +103,11 @@ private TypedCursorable(IHostEnvironment env, IDataView data, bool ignoreMissing throw _host.Except("Column '{0}' not found in the data view", col.ColumnName); } var realColType = _data.Schema.GetColumnType(colIndex); - if (!IsCompatibleType(realColType, col.FieldInfo)) + if (!IsCompatibleType(realColType, col.MemberInfo)) { throw _host.Except( - "Can't bind the IDataView column '{0}' of type '{1}' to field '{2}' of type '{3}'.", - col.ColumnName, realColType, col.FieldInfo.Name, col.FieldInfo.FieldType.FullName); + "Can't bind the IDataView column '{0}' of type '{1}' to field or property '{2}' of type '{3}'.", + col.ColumnName, realColType, col.MemberInfo.Name, col.OutputType.FullName); } acceptedCols.Add(col); @@ -130,14 +130,12 @@ private TypedCursorable(IHostEnvironment env, IDataView data, bool ignoreMissing } /// - /// Returns whether the column type can be bound to field . + /// Returns whether the column type can be bound to field . /// They must both be vectors or scalars, and the raw data kind should match. /// - private static bool IsCompatibleType(ColumnType colType, FieldInfo fieldInfo) + private static bool IsCompatibleType(ColumnType colType, MemberInfo memberInfo) { - bool isVector; - DataKind kind; - InternalSchemaDefinition.GetVectorAndKind(fieldInfo, out isVector, out kind); + InternalSchemaDefinition.GetVectorAndKind(memberInfo, out bool isVector, out DataKind kind); if (isVector) return colType.IsVector && colType.ItemType.RawKind == kind; else @@ -269,8 +267,7 @@ public ValueGetter GetIdGetter() private Action GenerateSetter(IRow input, int index, InternalSchemaDefinition.Column column, Delegate poke, Delegate peek) { var colType = input.Schema.GetColumnType(index); - var fieldInfo = column.FieldInfo; - var fieldType = fieldInfo.FieldType; + var fieldType = column.OutputType; var genericType = fieldType; Func> del; if (fieldType.IsArray) @@ -430,7 +427,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit else { // REVIEW: Is this even possible? - throw Ch.ExceptNotImpl("Type '{0}' is not yet supported.", fieldInfo.FieldType.FullName); + throw Ch.ExceptNotImpl("Type '{0}' is not yet supported.", column.OutputType.FullName); } MethodInfo meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(genericType); return (Action)meth.Invoke(this, new object[] { input, index, poke, peek }); diff --git a/src/Microsoft.ML/Data/TextLoader.cs b/src/Microsoft.ML/Data/TextLoader.cs index 6e89e8a54e..1301ed6896 100644 --- a/src/Microsoft.ML/Data/TextLoader.cs +++ b/src/Microsoft.ML/Data/TextLoader.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using System; +using System.Collections.Generic; using System.Linq; using System.Reflection; using System.Text.RegularExpressions; @@ -71,20 +72,31 @@ public TextLoader CreateFrom(bool useHeader = false, char separator = '\t', bool allowQuotedStrings = true, bool supportSparse = true, bool trimWhitespace = false) { - var fields = typeof(TInput).GetFields(); - Arguments.Column = new TextLoaderColumn[fields.Length]; - for (int index = 0; index < fields.Length; index++) + var userType = typeof(TInput); + + var fieldInfos = userType.GetFields(); + + var propertyInfos = + userType + .GetProperties(BindingFlags.Public | BindingFlags.Instance) + .Where(x => x.CanRead && x.CanWrite && x.GetIndexParameters().Length == 0) + .ToArray(); + + var memberInfos = (fieldInfos as IEnumerable).Concat(propertyInfos).ToArray(); + + Arguments.Column = new TextLoaderColumn[memberInfos.Length]; + for (int index = 0; index < memberInfos.Length; index++) { - var field = fields[index]; - var mappingAttr = field.GetCustomAttribute(); + var memberInfo = memberInfos[index]; + var mappingAttr = memberInfo.GetCustomAttribute(); if (mappingAttr == null) - throw Contracts.Except($"{field.Name} is missing ColumnAttribute"); + throw Contracts.Except($"field or property {memberInfo.Name} is missing ColumnAttribute"); if (Regex.Match(mappingAttr.Ordinal, @"[^(0-9,\*\-~)]+").Success) throw Contracts.Except($"{mappingAttr.Ordinal} contains invalid characters. " + $"Valid characters are 0-9, *, - and ~"); - var name = mappingAttr.Name ?? field.Name; + var name = mappingAttr.Name ?? memberInfo.Name; Runtime.Data.TextLoader.Range[] sources; if (!Runtime.Data.TextLoader.Column.TryParseSourceEx(mappingAttr.Ordinal, out sources)) @@ -96,8 +108,22 @@ public TextLoader CreateFrom(bool useHeader = false, tlc.Name = name; tlc.Source = new TextLoaderRange[sources.Length]; DataKind dk; - if (!TryGetDataKind(field.FieldType.IsArray ? field.FieldType.GetElementType() : field.FieldType, out dk)) - throw Contracts.Except($"{name} is of unsupported type."); + switch (memberInfo) + { + case FieldInfo field: + if (!TryGetDataKind(field.FieldType.IsArray ? field.FieldType.GetElementType() : field.FieldType, out dk)) + throw Contracts.Except($"field {name} is of unsupported type."); + + break; + + case PropertyInfo property: + if (!TryGetDataKind(property.PropertyType.IsArray ? property.PropertyType.GetElementType() : property.PropertyType, out dk)) + throw Contracts.Except($"property {name} is of unsupported type."); + break; + + default: + throw Contracts.ExceptNotSupp("expected a FieldInfo or a PropInfo"); + } tlc.Type = dk; diff --git a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs index 1306314427..ab1f4a6f9a 100644 --- a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs +++ b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs @@ -76,7 +76,7 @@ module SmokeTest1 = val mutable Sentiment : float32 type SentimentPrediction() = - [] + [] val mutable Sentiment : bool [] @@ -129,13 +129,14 @@ module SmokeTest1 = module SmokeTest2 = + [] type SentimentData = - { [] SentimentText : string - [] Sentiment : float } + { [] SentimentText : string + [] Sentiment : float32 } [] type SentimentPrediction = - { [] Sentiment : bool } + { [] Sentiment : bool } [] let ``FSharp-Sentiment-Smoke-Test`` () = @@ -177,9 +178,9 @@ module SmokeTest2 = let model = pipeline.Train() let predictions = - [ SentimentData(SentimentText = "This is a gross exaggeration. Nobody is setting a kangaroo court. There was a simple addition.") - SentimentData(SentimentText = "Sort of ok") - SentimentData(SentimentText = "Joe versus the Volcano Coffee Company is a great film.") ] + [ { SentimentText = "This is a gross exaggeration. Nobody is setting a kangaroo court. There was a simple addition."; Sentiment = 0.0f } + { SentimentText = "Sort of ok"; Sentiment = 0.0f } + { SentimentText = "Joe versus the Volcano Coffee Company is a great film."; Sentiment = 0.0f } ] |> model.Predict let predictionResults = [ for p in predictions -> p.Sentiment ] From 9bea7e83ab0715ef52b4688f7bdc08e2486745ea Mon Sep 17 00:00:00 2001 From: Don Syme Date: Tue, 31 Jul 2018 14:00:04 +0100 Subject: [PATCH 05/15] allow attributes to be used on properties --- src/Microsoft.ML.Api/SchemaDefinition.cs | 12 ++++++------ test/Microsoft.ML.FSharp.Tests/SmokeTests.fs | 10 +++++++--- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/Microsoft.ML.Api/SchemaDefinition.cs b/src/Microsoft.ML.Api/SchemaDefinition.cs index 5a67b02b4c..22e02e0fa3 100644 --- a/src/Microsoft.ML.Api/SchemaDefinition.cs +++ b/src/Microsoft.ML.Api/SchemaDefinition.cs @@ -14,7 +14,7 @@ namespace Microsoft.ML.Runtime.Api /// /// Attach to a member of a class to indicate that the item type should be of class key. /// - [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)] + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] public sealed class KeyTypeAttribute : Attribute { // REVIEW: Property based, but should I just have a constructor? @@ -46,7 +46,7 @@ public KeyTypeAttribute() /// Allows a member to be marked as a vector valued field, primarily allowing one to set /// the dimensionality of the resulting array. /// - [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)] + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] public sealed class VectorTypeAttribute : Attribute { private readonly int[] _dims; @@ -66,7 +66,7 @@ public VectorTypeAttribute(params int[] dims) /// Describes column information such as name and the source columns indicies that this /// column encapsulates. /// - [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)] + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] public sealed class ColumnAttribute : Attribute { public ColumnAttribute(string ordinal, string name = null) @@ -97,7 +97,7 @@ public ColumnAttribute(string ordinal, string name = null) /// Allows a member to specify its column name directly, as opposed to the default /// behavior of using the member name as the column name. /// - [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)] + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] public sealed class ColumnNameAttribute : Attribute { private readonly string _name; @@ -119,7 +119,7 @@ public ColumnNameAttribute(string name) /// /// Mark this member as not being exposed as a column in the schema. /// - [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)] + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] public sealed class NoColumnAttribute : Attribute { } @@ -128,7 +128,7 @@ public sealed class NoColumnAttribute : Attribute /// Mark a member that implements exactly IChannel as being permitted to receive /// channel information from an external channel. /// - [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)] + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)] public sealed class CursorChannelAttribute : Attribute { /// diff --git a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs index ab1f4a6f9a..4f19dccabe 100644 --- a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs +++ b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs @@ -131,12 +131,16 @@ module SmokeTest2 = [] type SentimentData = - { [] SentimentText : string - [] Sentiment : float32 } + { [] + SentimentText : string + + [] + Sentiment : float32 } [] type SentimentPrediction = - { [] Sentiment : bool } + { [] + Sentiment : bool } [] let ``FSharp-Sentiment-Smoke-Test`` () = From 382906e3d0d4938d0173090f058cad71c86dac5e Mon Sep 17 00:00:00 2001 From: Don Syme Date: Tue, 31 Jul 2018 14:44:36 +0100 Subject: [PATCH 06/15] update test case --- test/Microsoft.ML.Tests/TextLoaderTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Microsoft.ML.Tests/TextLoaderTests.cs b/test/Microsoft.ML.Tests/TextLoaderTests.cs index 61a1744dfb..9c853ffceb 100644 --- a/test/Microsoft.ML.Tests/TextLoaderTests.cs +++ b/test/Microsoft.ML.Tests/TextLoaderTests.cs @@ -228,7 +228,7 @@ public void CanSuccessfullyTrimSpaces() public void ThrowsExceptionWithPropertyName() { Exception ex = Assert.Throws( () => new Data.TextLoader("fakefile.txt").CreateFrom() ); - Assert.StartsWith("String1 is missing ColumnAttribute", ex.Message); + Assert.StartsWith("field or property String1 is missing ColumnAttribute", ex.Message); } public class QuoteInput From 4e131b5e4fadf7c9587707cbade211b70d0f6b8e Mon Sep 17 00:00:00 2001 From: Don Syme Date: Tue, 31 Jul 2018 14:52:25 +0100 Subject: [PATCH 07/15] force a static reference to Microsoft.ML.Transforms --- test/Microsoft.ML.FSharp.Tests/SmokeTests.fs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs index c6c05d9714..b4f333aceb 100644 --- a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs +++ b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs @@ -79,6 +79,12 @@ module SmokeTest1 = [] val mutable Sentiment : bool + let _load = + // See https://github.com/dotnet/machinelearning/issues/401: forces the loading of ML.NET assemblies + // This is needed even for compiled code + [ typeof; + typeof ] + [] let ``FSharp-Sentiment-Smoke-Test`` () = From 31eefbc22d8c354bb976714ae11941c6001206d3 Mon Sep 17 00:00:00 2001 From: Don Syme Date: Tue, 31 Jul 2018 22:14:23 +0100 Subject: [PATCH 08/15] fix failing test --- .../Microsoft.ML.FSharp.Tests.fsproj | 2 ++ test/Microsoft.ML.FSharp.Tests/Program.fs | 9 +++++++ test/Microsoft.ML.FSharp.Tests/SmokeTests.fs | 24 +++++-------------- 3 files changed, 17 insertions(+), 18 deletions(-) create mode 100644 test/Microsoft.ML.FSharp.Tests/Program.fs diff --git a/test/Microsoft.ML.FSharp.Tests/Microsoft.ML.FSharp.Tests.fsproj b/test/Microsoft.ML.FSharp.Tests/Microsoft.ML.FSharp.Tests.fsproj index 196ec1b893..888a983e51 100644 --- a/test/Microsoft.ML.FSharp.Tests/Microsoft.ML.FSharp.Tests.fsproj +++ b/test/Microsoft.ML.FSharp.Tests/Microsoft.ML.FSharp.Tests.fsproj @@ -6,10 +6,12 @@ 2003;$(NoWarn) false + x64 + diff --git a/test/Microsoft.ML.FSharp.Tests/Program.fs b/test/Microsoft.ML.FSharp.Tests/Program.fs new file mode 100644 index 0000000000..f45e4e3c6c --- /dev/null +++ b/test/Microsoft.ML.FSharp.Tests/Program.fs @@ -0,0 +1,9 @@ +namespace Microsoft.ML.FSharp.Tests + +#if NETCOREAPP2_0 +module Program = + + [] + let main _ = 0 +#endif + diff --git a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs index b4f333aceb..ea99a88010 100644 --- a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs +++ b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs @@ -45,11 +45,6 @@ // Later tests will add data import using F# type providers: //#r @"../../packages/fsharp.data/3.0.0-beta4/lib/netstandard2.0/FSharp.Data.dll" // this must be referenced from its package location -let _load = - // See https://github.com/dotnet/machinelearning/issues/401: forces the loading of ML.NET assemblies - [ typeof; - typeof ] - #endif //================================================================================ @@ -79,15 +74,15 @@ module SmokeTest1 = [] val mutable Sentiment : bool - let _load = - // See https://github.com/dotnet/machinelearning/issues/401: forces the loading of ML.NET assemblies - // This is needed even for compiled code - [ typeof; - typeof ] - [] let ``FSharp-Sentiment-Smoke-Test`` () = + // See https://github.com/dotnet/machinelearning/issues/401: forces the loading of ML.NET component assemblies + + let _load = + [ typeof; + typeof ] + let testDataPath = __SOURCE_DIRECTORY__ + @"/../data/wikipedia-detox-250-line-data.tsv" let pipeline = LearningPipeline() @@ -129,10 +124,3 @@ module SmokeTest1 = let predictionResults = [ for p in predictions -> p.Sentiment ] Assert.Equal(predictionResults, [ false; true; true ]) -#if NETCOREAPP2_0 -module Program = - - [] - let main _ = 0 -#endif - From 47db808d006152dc46e5f58a36b9942073fc421b Mon Sep 17 00:00:00 2001 From: Don Syme Date: Tue, 31 Jul 2018 22:31:58 +0100 Subject: [PATCH 09/15] add some C# testing --- src/Microsoft.ML.Api/ApiUtils.cs | 8 +- .../InternalSchemaDefinition.cs | 3 +- src/Microsoft.ML.Api/TypedCursor.cs | 2 +- .../CollectionDataSourceTests.cs | 144 ++++++++++++++++++ 4 files changed, 153 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.ML.Api/ApiUtils.cs b/src/Microsoft.ML.Api/ApiUtils.cs index 54846eb631..e485dd2da3 100644 --- a/src/Microsoft.ML.Api/ApiUtils.cs +++ b/src/Microsoft.ML.Api/ApiUtils.cs @@ -103,10 +103,12 @@ private static Delegate GeneratePeek(PropertyInfo propertyIn Type[] args = { typeof(TOwn), typeof(TRow), typeof(long), typeof(TValue).MakeByRefType() }; var mb = new DynamicMethod("Peek", null, args, typeof(TOwn), true); var il = mb.GetILGenerator(); + var minfo = propertyInfo.GetGetMethod(); + var opcode = minfo.IsVirtual ? OpCodes.Callvirt : OpCodes.Call; il.Emit(OpCodes.Ldarg_3); // push arg3 il.Emit(OpCodes.Ldarg_1); // push arg1 - il.Emit(OpCodes.Call, propertyInfo.GetGetMethod()); // push [stack top].[propertyInfo] + il.Emit(opcode, minfo); // push [stack top].[propertyInfo] // Stobj needs to coupled with a type. if (assignmentOpCode == OpCodes.Stobj) // [stack top-1] = [stack top] il.Emit(assignmentOpCode, propertyInfo.PropertyType); @@ -172,10 +174,12 @@ private static Delegate GeneratePoke(PropertyInfo propertyIn Type[] args = { typeof(TOwn), typeof(TRow), typeof(TValue) }; var mb = new DynamicMethod("Poke", null, args, typeof(TOwn), true); var il = mb.GetILGenerator(); + var minfo = propertyInfo.GetSetMethod(); + var opcode = minfo.IsVirtual ? OpCodes.Callvirt : OpCodes.Call; il.Emit(OpCodes.Ldarg_1); // push arg1 il.Emit(OpCodes.Ldarg_2); // push arg2 - il.Emit(OpCodes.Call, propertyInfo.GetSetMethod()); // [stack top-1].[propertyInfo] <- [stack top] + il.Emit(opcode, minfo); // [stack top-1].[propertyInfo] <- [stack top] il.Emit(OpCodes.Ret); // ret return mb.CreateDelegate(typeof(Poke), null); } diff --git a/src/Microsoft.ML.Api/InternalSchemaDefinition.cs b/src/Microsoft.ML.Api/InternalSchemaDefinition.cs index dcc804c1c4..863f4073b2 100644 --- a/src/Microsoft.ML.Api/InternalSchemaDefinition.cs +++ b/src/Microsoft.ML.Api/InternalSchemaDefinition.cs @@ -31,7 +31,8 @@ public class Column private readonly Dictionary _metadata; public Dictionary Metadata { get { return _metadata; } } public Type ComputedReturnType {get { return ReturnParameterInfo.ParameterType.GetElementType(); }} - public Type OutputType => IsComputed ? ComputedReturnType : (MemberInfo is FieldInfo) ? (MemberInfo as FieldInfo).FieldType : (MemberInfo as PropertyInfo).PropertyType; + public Type FieldOrPropertyType => (MemberInfo is FieldInfo) ? (MemberInfo as FieldInfo).FieldType : (MemberInfo as PropertyInfo).PropertyType; + public Type OutputType => IsComputed ? ComputedReturnType : FieldOrPropertyType; public Column(string columnName, ColumnType columnType, MemberInfo memberInfo) : this(columnName, columnType, memberInfo, null, null) { } diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index ee7fb23143..5fd0d97a4d 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -107,7 +107,7 @@ private TypedCursorable(IHostEnvironment env, IDataView data, bool ignoreMissing { throw _host.Except( "Can't bind the IDataView column '{0}' of type '{1}' to field or property '{2}' of type '{3}'.", - col.ColumnName, realColType, col.MemberInfo.Name, col.OutputType.FullName); + col.ColumnName, realColType, col.MemberInfo.Name, col.FieldOrPropertyType.FullName); } acceptedCols.Add(col); diff --git a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs index 87e23952d6..f968b788de 100644 --- a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs +++ b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs @@ -275,6 +275,26 @@ public bool CompareThroughReflection(T x, T y) return true; } + public bool CompareThroughReflectionProperties(T x, T y) + { + foreach (var property in typeof(T).GetProperties()) + { + var xvalue = property.GetValue(x); + var yvalue = property.GetValue(y); + if (property.PropertyType.IsArray) + { + if (!CompareArrayValues(xvalue as Array, yvalue as Array)) + return false; + } + else + { + if (!CompareObjectValues(xvalue, yvalue, property.PropertyType)) + return false; + } + } + return true; + } + public bool CompareArrayValues(Array x, Array y) { if (x == null && y == null) return true; @@ -609,5 +629,129 @@ public void RoundTripConversionWithArrays() Assert.True(!enumeratorNullable.MoveNext() && !originalNullalbleEnumerator.MoveNext()); } } + public class ClassWithArrayProperties + { + private string[] _fString; + private int[] _fInt; + private uint[] _fuInt; + private short[] _fShort; + private ushort[] _fuShort; + private sbyte[] _fsByte; + private byte[] _fByte; + private long[] _fLong; + private ulong[] _fuLong; + private float[] _fFloat; + private double[] _fDouble; + private bool[] _fBool; + public string[] fString { get { return _fString; } set { _fString = value; } } + public int[] fInt { get { return _fInt; } set { _fInt = value; } } + public uint[] fuInt { get { return _fuInt; } set { _fuInt = value; } } + public short[] fShort { get { return _fShort; } set { _fShort = value; } } + public ushort[] fuShort { get { return _fuShort; } set { _fuShort = value; } } + public sbyte[] fsByte { get { return _fsByte; } set { _fsByte = value; } } + public byte[] fByte { get { return _fByte; } set { _fByte = value; } } + public long[] fLong { get { return _fLong; } set { _fLong = value; } } + public ulong[] fuLong { get { return _fuLong; } set { _fuLong = value; } } + public float[] fFloat { get { return _fFloat; } set { _fFloat = value; } } + public double[] fDouble { get { return _fDouble; } set { _fDouble = value; } } + public bool[] fBool { get { return _fBool; } set { _fBool = value; } } + } + + public class ClassWithNullableArrayProperties + { + private string[] _fString; + private int?[] _fInt; + private uint?[] _fuInt; + private short?[] _fShort; + private ushort?[] _fuShort; + private sbyte?[] _fsByte; + private byte?[] _fByte; + private long?[] _fLong; + private ulong?[] _fuLong; + private float?[] _fFloat; + private double?[] _fDouble; + private bool?[] _fBool; + + public string[] fString { get { return _fString; } set { _fString = value; } } + public int?[] fInt { get { return _fInt; } set { _fInt = value; } } + public uint?[] fuInt { get { return _fuInt; } set { _fuInt = value; } } + public short?[] fShort { get { return _fShort; } set { _fShort = value; } } + public ushort?[] fuShort { get { return _fuShort; } set { _fuShort = value; } } + public sbyte?[] fsByte { get { return _fsByte; } set { _fsByte = value; } } + public byte?[] fByte { get { return _fByte; } set { _fByte = value; } } + public long?[] fLong { get { return _fLong; } set { _fLong = value; } } + public ulong?[] fuLong { get { return _fuLong; } set { _fuLong = value; } } + public float?[] fFloat { get { return _fFloat; } set { _fFloat = value; } } + public double?[] fDouble { get { return _fDouble; } set { _fDouble = value; } } + public bool?[] fBool { get { return _fBool; } set { _fBool = value; } } + } + + [Fact] + public void RoundTripConversionWithArrayPropertiess() + { + + var data = new List + { + new ClassWithArrayProperties() + { + fInt = new int[3] { 0, 1, 2 }, + fFloat = new float[3] { -0.99f, 0f, 0.99f }, + fString = new string[2] { "hola", "lola" }, + fBool = new bool[2] { true, false }, + fByte = new byte[3] { 0, 124, 255 }, + fDouble = new double[3] { -1, 0, 1 }, + fLong = new long[] { 0, 1, 2 }, + fsByte = new sbyte[3] { -127, 127, 0 }, + fShort = new short[3] { 0, 1225, 32767 }, + fuInt = new uint[2] { 0, uint.MaxValue }, + fuLong = new ulong[2] { ulong.MaxValue, 0 }, + fuShort = new ushort[2] { 0, ushort.MaxValue } + }, + new ClassWithArrayProperties() { fInt = new int[3] { -2, 1, 0 }, fFloat = new float[3] { 0.99f, 0f, -0.99f }, fString = new string[2] { "", null } }, + new ClassWithArrayProperties() + }; + + var nullableData = new List + { + new ClassWithNullableArrayProperties() + { + fInt = new int?[3] { null, -1, 1 }, + fFloat = new float?[3] { -0.99f, null, 0.99f }, + fString = new string[2] { null, "" }, + fBool = new bool?[3] { true, null, false }, + fByte = new byte?[4] { 0, 125, null, 255 }, + fDouble = new double?[3] { -1, null, 1 }, + fLong = new long?[] { null, -1, 1 }, + fsByte = new sbyte?[3] { -127, 127, null }, + fShort = new short?[3] { 0, null, 32767 }, + fuInt = new uint?[4] { null, 42, 0, uint.MaxValue }, + fuLong = new ulong?[3] { ulong.MaxValue, null, 0 }, + fuShort = new ushort?[3] { 0, null, ushort.MaxValue } + }, + new ClassWithNullableArrayProperties() { fInt = new int?[3] { -2, 1, 0 }, fFloat = new float?[3] { 0.99f, 0f, -0.99f }, fString = new string[2] { "lola", "hola" } }, + new ClassWithNullableArrayProperties() + }; + + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) + { + Assert.True(CompareThroughReflectionProperties(enumeratorSimple.Current, originalEnumerator.Current)); + } + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); + + var nullableDataView = ComponentCreation.CreateDataView(env, nullableData); + var enumeratorNullable = nullableDataView.AsEnumerable(env, false).GetEnumerator(); + var originalNullalbleEnumerator = nullableData.GetEnumerator(); + while (enumeratorNullable.MoveNext() && originalNullalbleEnumerator.MoveNext()) + { + Assert.True(CompareThroughReflectionProperties(enumeratorNullable.Current, originalNullalbleEnumerator.Current)); + } + Assert.True(!enumeratorNullable.MoveNext() && !originalNullalbleEnumerator.MoveNext()); + } + } } } From 747e4171f0d45c07ebdd4bc4716618e0b37b2479 Mon Sep 17 00:00:00 2001 From: Don Syme Date: Tue, 31 Jul 2018 23:00:00 +0100 Subject: [PATCH 10/15] add more C# tests --- .../CollectionDataSourceTests.cs | 201 ++++++++++++++++-- 1 file changed, 186 insertions(+), 15 deletions(-) diff --git a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs index f968b788de..2de80b79c2 100644 --- a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs +++ b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs @@ -174,6 +174,49 @@ public void CanTrain() } + [Fact] + public void CanTrainProperties() + { + var pipeline = new LearningPipeline(); + var data = new List() { + new IrisDataProperties { SepalLength = 1f, SepalWidth = 1f, PetalLength=0.3f, PetalWidth=5.1f, Label=1}, + new IrisDataProperties { SepalLength = 1f, SepalWidth = 1f, PetalLength=0.3f, PetalWidth=5.1f, Label=1}, + new IrisDataProperties { SepalLength = 1.2f, SepalWidth = 0.5f, PetalLength=0.3f, PetalWidth=5.1f, Label=0} + }; + var collection = CollectionDataSource.Create(data); + + pipeline.Add(collection); + pipeline.Add(new ColumnConcatenator(outputColumn: "Features", + "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")); + pipeline.Add(new StochasticDualCoordinateAscentClassifier()); + PredictionModel model = pipeline.Train(); + + IrisPredictionProperties prediction = model.Predict(new IrisDataProperties() + { + SepalLength = 3.3f, + SepalWidth = 1.6f, + PetalLength = 0.2f, + PetalWidth = 5.1f, + }); + + pipeline = new LearningPipeline(); + collection = CollectionDataSource.Create(data.AsEnumerable()); + pipeline.Add(collection); + pipeline.Add(new ColumnConcatenator(outputColumn: "Features", + "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")); + pipeline.Add(new StochasticDualCoordinateAscentClassifier()); + model = pipeline.Train(); + + prediction = model.Predict(new IrisDataProperties() + { + SepalLength = 3.3f, + SepalWidth = 1.6f, + PetalLength = 0.2f, + PetalWidth = 5.1f, + }); + + } + public class Input { [Column("0")] @@ -207,6 +250,37 @@ public class IrisPrediction public float[] PredictedLabels; } + public class IrisDataProperties + { + private float _Label; + private float _SepalLength; + private float _SepalWidth; + private float _PetalLength; + private float _PetalWidth; + + [Column("0")] + public float Label { get { return _Label; } set { _Label = value; } } + + [Column("1")] + public float SepalLength { get { return _SepalLength; } set { _SepalLength = value; } } + + [Column("2")] + public float SepalWidth { get { return _SepalWidth; } set { _SepalWidth = value; } } + + [Column("3")] + public float PetalLength { get { return _PetalLength; } set { _PetalLength = value; } } + + [Column("4")] + public float PetalWidth { get { return _PetalWidth; } set { _PetalWidth = value; } } + } + + public class IrisPredictionProperties + { + private float[] _PredictedLabels; + [ColumnName("Score")] + public float[] PredictedLabels { get { return _PredictedLabels; } set { _PredictedLabels = value; } } + } + public class ConversionSimpleClass { public int fInt; @@ -272,11 +346,6 @@ public bool CompareThroughReflection(T x, T y) return false; } } - return true; - } - - public bool CompareThroughReflectionProperties(T x, T y) - { foreach (var property in typeof(T).GetProperties()) { var xvalue = property.GetValue(x); @@ -308,14 +377,6 @@ public bool CompareArrayValues(Array x, Array y) return true; } - public class ClassWithConstField - { - public const string ConstString = "N"; - public string fString; - public const int ConstInt = 100; - public int fInt; - } - [Fact] public void RoundTripConversionWithBasicTypes() { @@ -509,6 +570,50 @@ public void ConversionMinValueToNullBehavior() } } + public class ConversionLossMinValueClassProperties + { + private int? _fInt; + private long? _fLong; + private short? _fShort; + private sbyte? _fsByte; + public int? fInt { get { return _fInt; } set { _fInt = value; } } + public short? fShort { get { return _fShort; } set { _fShort = value; } } + public sbyte? fsByte { get { return _fsByte; } set { _fsByte = value; } } + public long? fLong { get { return _fLong; } set { _fLong = value; } } + } + + [Fact] + public void ConversionMinValueToNullBehaviorProperties() + { + using (var env = new TlcEnvironment()) + { + + var data = new List + { + new ConversionLossMinValueClassProperties() { fsByte = null, fInt = null, fLong = null, fShort = null }, + new ConversionLossMinValueClassProperties() { fsByte = sbyte.MinValue, fInt = int.MinValue, fLong = long.MinValue, fShort = short.MinValue } + }; + foreach (var field in typeof(ConversionLossMinValueClassProperties).GetFields()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var enumerator = dataView.AsEnumerable(env, false).GetEnumerator(); + while (enumerator.MoveNext()) + { + Assert.True(enumerator.Current.fInt == null && enumerator.Current.fLong == null && + enumerator.Current.fsByte == null && enumerator.Current.fShort == null); + } + } + } + } + + public class ClassWithConstField + { + public const string ConstString = "N"; + public string fString; + public const int ConstInt = 100; + public int fInt; + } + [Fact] public void ClassWithConstFieldsConversion() { @@ -530,6 +635,72 @@ public void ClassWithConstFieldsConversion() } } + + public class ClassWithMixOfFieldsAndProperties + { + public string fString; + private int _fInt; + public int fInt { get { return _fInt; } set { _fInt = value; } } + } + + [Fact] + public void ClassWithMixOfFieldsAndPropertiesConversion() + { + var data = new List() + { + new ClassWithMixOfFieldsAndProperties(){ fInt=1, fString ="lala" }, + new ClassWithMixOfFieldsAndProperties(){ fInt=-1, fString ="" }, + new ClassWithMixOfFieldsAndProperties(){ fInt=0, fString =null } + }; + + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); + } + } + + public abstract class BaseClassWithInheritedProperties + { + private string _fString; + public string fString { get { return _fString; } set { _fString = value; } } + public abstract long fLong { get; set; } + } + + + public class ClassWithInheritedProperties : BaseClassWithInheritedProperties + { + private int _fInt; + private long _fLong; + public int fInt { get { return _fInt; } set { _fInt = value; } } + public override long fLong { get { return _fLong; } set { _fLong = value; } } + } + + [Fact] + public void ClassWithInheritedPropertiesConversion() + { + var data = new List() + { + new ClassWithInheritedProperties(){ fInt=1, fString ="lala", fLong=17 }, + new ClassWithInheritedProperties(){ fInt=-1, fString ="", fLong=2 }, + new ClassWithInheritedProperties(){ fInt=0, fString =null, fLong=18 } + }; + + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); + } + } + public class ClassWithArrays { public string[] fString; @@ -739,7 +910,7 @@ public void RoundTripConversionWithArrayPropertiess() var originalEnumerator = data.GetEnumerator(); while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) { - Assert.True(CompareThroughReflectionProperties(enumeratorSimple.Current, originalEnumerator.Current)); + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); } Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); @@ -748,7 +919,7 @@ public void RoundTripConversionWithArrayPropertiess() var originalNullalbleEnumerator = nullableData.GetEnumerator(); while (enumeratorNullable.MoveNext() && originalNullalbleEnumerator.MoveNext()) { - Assert.True(CompareThroughReflectionProperties(enumeratorNullable.Current, originalNullalbleEnumerator.Current)); + Assert.True(CompareThroughReflection(enumeratorNullable.Current, originalNullalbleEnumerator.Current)); } Assert.True(!enumeratorNullable.MoveNext() && !originalNullalbleEnumerator.MoveNext()); } From 710fe14798a0a7f3525113f1b2788ba653d32590 Mon Sep 17 00:00:00 2001 From: Don Syme Date: Tue, 31 Jul 2018 23:01:36 +0100 Subject: [PATCH 11/15] add IsAbstract --- src/Microsoft.ML.Api/ApiUtils.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Api/ApiUtils.cs b/src/Microsoft.ML.Api/ApiUtils.cs index 167ad8bb10..05a7b26f29 100644 --- a/src/Microsoft.ML.Api/ApiUtils.cs +++ b/src/Microsoft.ML.Api/ApiUtils.cs @@ -104,7 +104,7 @@ private static Delegate GeneratePeek(PropertyInfo propertyIn var mb = new DynamicMethod("Peek", null, args, typeof(TOwn), true); var il = mb.GetILGenerator(); var minfo = propertyInfo.GetGetMethod(); - var opcode = minfo.IsVirtual ? OpCodes.Callvirt : OpCodes.Call; + var opcode = (minfo.IsVirtual || minfo.IsAbstract) ? OpCodes.Callvirt : OpCodes.Call; il.Emit(OpCodes.Ldarg_3); // push arg3 il.Emit(OpCodes.Ldarg_1); // push arg1 @@ -175,7 +175,7 @@ private static Delegate GeneratePoke(PropertyInfo propertyIn var mb = new DynamicMethod("Poke", null, args, typeof(TOwn), true); var il = mb.GetILGenerator(); var minfo = propertyInfo.GetSetMethod(); - var opcode = minfo.IsVirtual ? OpCodes.Callvirt : OpCodes.Call; + var opcode = (minfo.IsVirtual || minfo.IsAbstract) ? OpCodes.Callvirt : OpCodes.Call; il.Emit(OpCodes.Ldarg_1); // push arg1 il.Emit(OpCodes.Ldarg_2); // push arg2 From 4e19c25f63ee0fd2c7b9a75e1842ef93181c6847 Mon Sep 17 00:00:00 2001 From: Don Syme Date: Wed, 1 Aug 2018 00:37:03 +0100 Subject: [PATCH 12/15] fix strings --- src/Microsoft.ML/Data/TextLoader.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML/Data/TextLoader.cs b/src/Microsoft.ML/Data/TextLoader.cs index a0681d707f..20b3d72619 100644 --- a/src/Microsoft.ML/Data/TextLoader.cs +++ b/src/Microsoft.ML/Data/TextLoader.cs @@ -111,13 +111,13 @@ public TextLoader CreateFrom(bool useHeader = false, { case FieldInfo field: if (!TryGetDataKind(field.FieldType.IsArray ? field.FieldType.GetElementType() : field.FieldType, out dk)) - throw Contracts.Except($"field {name} is of unsupported type."); + throw Contracts.Except($"Field {name} is of unsupported type."); break; case PropertyInfo property: if (!TryGetDataKind(property.PropertyType.IsArray ? property.PropertyType.GetElementType() : property.PropertyType, out dk)) - throw Contracts.Except($"property {name} is of unsupported type."); + throw Contracts.Except($"Property {name} is of unsupported type."); break; default: From 431ad89a84518d69000947058104dc8dfbb73dd4 Mon Sep 17 00:00:00 2001 From: Don Syme Date: Wed, 1 Aug 2018 19:46:25 +0100 Subject: [PATCH 13/15] add more tests --- src/Microsoft.ML.Api/ApiUtils.cs | 2 + .../InternalSchemaDefinition.cs | 1 + src/Microsoft.ML.Api/SchemaDefinition.cs | 8 +- src/Microsoft.ML/Data/TextLoader.cs | 5 +- .../CollectionDataSourceTests.cs | 196 +++++++++++------- 5 files changed, 136 insertions(+), 76 deletions(-) diff --git a/src/Microsoft.ML.Api/ApiUtils.cs b/src/Microsoft.ML.Api/ApiUtils.cs index 05a7b26f29..22afda7608 100644 --- a/src/Microsoft.ML.Api/ApiUtils.cs +++ b/src/Microsoft.ML.Api/ApiUtils.cs @@ -72,6 +72,7 @@ internal static Delegate GeneratePeek(InternalSchemaDefinition.Colum return (Delegate)methInfoProp.Invoke(null, new object[] { propertyInfo, assignmentOpCodeProp }); default: + Contracts.Assert(false); throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo"); } @@ -147,6 +148,7 @@ internal static Delegate GeneratePoke(InternalSchemaDefinition.Colum return (Delegate)methInfoProp.Invoke(null, new object[] { propertyInfo }); default: + Contracts.Assert(false); throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo"); } } diff --git a/src/Microsoft.ML.Api/InternalSchemaDefinition.cs b/src/Microsoft.ML.Api/InternalSchemaDefinition.cs index e3b30ec147..4c20f25d62 100644 --- a/src/Microsoft.ML.Api/InternalSchemaDefinition.cs +++ b/src/Microsoft.ML.Api/InternalSchemaDefinition.cs @@ -152,6 +152,7 @@ public static void GetVectorAndKind(MemberInfo memberInfo, out bool isVector, ou break; default: + Contracts.Assert(false); throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo"); } } diff --git a/src/Microsoft.ML.Api/SchemaDefinition.cs b/src/Microsoft.ML.Api/SchemaDefinition.cs index daeb414cd6..3258df4ffd 100644 --- a/src/Microsoft.ML.Api/SchemaDefinition.cs +++ b/src/Microsoft.ML.Api/SchemaDefinition.cs @@ -160,7 +160,7 @@ public static bool TrySetCursorChannel(IExceptionContext ectx, T obj, IChanne var cursorChannelAttrProperties = typeof(T) .GetProperties(BindingFlags.Public | BindingFlags.Instance) - .Where(x => x.CanRead && x.CanWrite && x.GetIndexParameters().Length == 0) + .Where(x => x.CanRead && x.CanWrite && x.GetGetMethod() != null && x.GetSetMethod() != null && x.GetIndexParameters().Length == 0) .Where(x => x.GetCustomAttributes(typeof(CursorChannelAttribute), false).Any()); var cursorChannelAttrMembers = (cursorChannelAttrFields as IEnumerable).Concat(cursorChannelAttrProperties).ToArray(); @@ -189,6 +189,7 @@ public static bool TrySetCursorChannel(IExceptionContext ectx, T obj, IChanne break; default: + Contracts.Assert(false); throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo"); } return true; @@ -340,11 +341,11 @@ public static SchemaDefinition Create(Type userType) SchemaDefinition cols = new SchemaDefinition(); HashSet colNames = new HashSet(); - var fieldInfos = userType.GetFields(); + var fieldInfos = userType.GetFields(BindingFlags.Public | BindingFlags.Instance); var propertyInfos = userType .GetProperties(BindingFlags.Public | BindingFlags.Instance) - .Where(x => x.CanRead && x.CanWrite && x.GetIndexParameters().Length == 0); + .Where(x => x.CanRead && x.CanWrite && x.GetGetMethod() != null && x.GetSetMethod() != null && x.GetIndexParameters().Length == 0); var memberInfos = (fieldInfos as IEnumerable).Concat(propertyInfos).ToArray(); @@ -372,6 +373,7 @@ public static SchemaDefinition Create(Type userType) break; default: + Contracts.Assert(false); throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo"); } diff --git a/src/Microsoft.ML/Data/TextLoader.cs b/src/Microsoft.ML/Data/TextLoader.cs index 20b3d72619..330412185e 100644 --- a/src/Microsoft.ML/Data/TextLoader.cs +++ b/src/Microsoft.ML/Data/TextLoader.cs @@ -74,12 +74,12 @@ public TextLoader CreateFrom(bool useHeader = false, { var userType = typeof(TInput); - var fieldInfos = userType.GetFields(); + var fieldInfos = userType.GetFields(BindingFlags.Public | BindingFlags.Instance); var propertyInfos = userType .GetProperties(BindingFlags.Public | BindingFlags.Instance) - .Where(x => x.CanRead && x.CanWrite && x.GetIndexParameters().Length == 0); + .Where(x => x.CanRead && x.CanWrite && x.GetGetMethod() != null && x.GetSetMethod() != null && x.GetIndexParameters().Length == 0); var memberInfos = (fieldInfos as IEnumerable).Concat(propertyInfos).ToArray(); @@ -121,6 +121,7 @@ public TextLoader CreateFrom(bool useHeader = false, break; default: + Contracts.Assert(false); throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo"); } diff --git a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs index 2de80b79c2..14a7f473f7 100644 --- a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs +++ b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs @@ -331,7 +331,7 @@ public bool CompareObjectValues(object x, object y, Type type) public bool CompareThroughReflection(T x, T y) { - foreach (var field in typeof(T).GetFields()) + foreach (var field in typeof(T).GetFields(BindingFlags.Public | BindingFlags.Instance)) { var xvalue = field.GetValue(x); var yvalue = field.GetValue(y); @@ -346,8 +346,12 @@ public bool CompareThroughReflection(T x, T y) return false; } } - foreach (var property in typeof(T).GetProperties()) + foreach (var property in typeof(T).GetProperties(BindingFlags.Public | BindingFlags.Instance)) { + // Don't compare properties with private getters and setters + if (!property.CanRead || !property.CanWrite || property.GetGetMethod() == null || property.GetSetMethod() == null) + continue; + var xvalue = property.GetValue(x); var yvalue = property.GetValue(y); if (property.PropertyType.IsArray) @@ -576,10 +580,10 @@ public class ConversionLossMinValueClassProperties private long? _fLong; private short? _fShort; private sbyte? _fsByte; - public int? fInt { get { return _fInt; } set { _fInt = value; } } - public short? fShort { get { return _fShort; } set { _fShort = value; } } - public sbyte? fsByte { get { return _fsByte; } set { _fsByte = value; } } - public long? fLong { get { return _fLong; } set { _fLong = value; } } + public int? IntProp { get { return _fInt; } set { _fInt = value; } } + public short? ShortProp { get { return _fShort; } set { _fShort = value; } } + public sbyte? SByteProp { get { return _fsByte; } set { _fsByte = value; } } + public long? LongProp { get { return _fLong; } set { _fLong = value; } } } [Fact] @@ -590,8 +594,8 @@ public void ConversionMinValueToNullBehaviorProperties() var data = new List { - new ConversionLossMinValueClassProperties() { fsByte = null, fInt = null, fLong = null, fShort = null }, - new ConversionLossMinValueClassProperties() { fsByte = sbyte.MinValue, fInt = int.MinValue, fLong = long.MinValue, fShort = short.MinValue } + new ConversionLossMinValueClassProperties() { SByteProp = null, IntProp = null, LongProp = null, ShortProp = null }, + new ConversionLossMinValueClassProperties() { SByteProp = sbyte.MinValue, IntProp = int.MinValue, LongProp = long.MinValue, ShortProp = short.MinValue } }; foreach (var field in typeof(ConversionLossMinValueClassProperties).GetFields()) { @@ -599,8 +603,8 @@ public void ConversionMinValueToNullBehaviorProperties() var enumerator = dataView.AsEnumerable(env, false).GetEnumerator(); while (enumerator.MoveNext()) { - Assert.True(enumerator.Current.fInt == null && enumerator.Current.fLong == null && - enumerator.Current.fsByte == null && enumerator.Current.fShort == null); + Assert.True(enumerator.Current.IntProp == null && enumerator.Current.LongProp == null && + enumerator.Current.SByteProp == null && enumerator.Current.ShortProp == null); } } } @@ -640,7 +644,7 @@ public class ClassWithMixOfFieldsAndProperties { public string fString; private int _fInt; - public int fInt { get { return _fInt; } set { _fInt = value; } } + public int IntProp { get { return _fInt; } set { _fInt = value; } } } [Fact] @@ -648,9 +652,9 @@ public void ClassWithMixOfFieldsAndPropertiesConversion() { var data = new List() { - new ClassWithMixOfFieldsAndProperties(){ fInt=1, fString ="lala" }, - new ClassWithMixOfFieldsAndProperties(){ fInt=-1, fString ="" }, - new ClassWithMixOfFieldsAndProperties(){ fInt=0, fString =null } + new ClassWithMixOfFieldsAndProperties(){ IntProp=1, fString ="lala" }, + new ClassWithMixOfFieldsAndProperties(){ IntProp=-1, fString ="" }, + new ClassWithMixOfFieldsAndProperties(){ IntProp=0, fString =null } }; using (var env = new TlcEnvironment()) @@ -667,17 +671,67 @@ public void ClassWithMixOfFieldsAndPropertiesConversion() public abstract class BaseClassWithInheritedProperties { private string _fString; - public string fString { get { return _fString; } set { _fString = value; } } - public abstract long fLong { get; set; } + private byte _fByte; + public string StringProp { get { return _fString; } set { _fString = value; } } + public abstract long LongProp { get; set; } + public virtual byte ByteProp { get { return _fByte; } set { _fByte = value; } } } + public class ClassWithPrivateFieldsAndProperties + { + public ClassWithPrivateFieldsAndProperties() { seq++; _unusedStaticField++; _unusedPrivateField1 = 100; } + static public int seq; + static public int _unusedStaticField; + private int _unusedPrivateField1; + private string _fString; + + // This property is ignored because it has no setter + private int UnusedReadOnlyProperty { get { return _unusedPrivateField1; } } + + // This property is ignored because it is private + private int UnusedPrivateProperty { get { return _unusedPrivateField1; } set { _unusedPrivateField1 = value; } } + + // This property is ignored because it has a private setter + public int UnusedPropertyWithPrivateSetter { get { return _unusedPrivateField1; } private set { _unusedPrivateField1 = value; } } + + // This property is ignored because it has a private getter + public int UnusedPropertyWithPrivateGetter { private get { return _unusedPrivateField1; } set { _unusedPrivateField1 = value; } } + + public string StringProp { get { return _fString; } set { _fString = value; } } + } + + [Fact] + public void ClassWithPrivateFieldsAndPropertiesConversion() + { + var data = new List() + { + new ClassWithPrivateFieldsAndProperties(){ StringProp ="lala" }, + new ClassWithPrivateFieldsAndProperties(){ StringProp ="baba" } + }; + + using (var env = new TlcEnvironment()) + { + var dataView = ComponentCreation.CreateDataView(env, data); + var enumeratorSimple = dataView.AsEnumerable(env, false).GetEnumerator(); + var originalEnumerator = data.GetEnumerator(); + while (enumeratorSimple.MoveNext() && originalEnumerator.MoveNext()) + { + Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); + Assert.True(enumeratorSimple.Current.UnusedPropertyWithPrivateSetter == 100); + } + Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); + } + } + public class ClassWithInheritedProperties : BaseClassWithInheritedProperties { private int _fInt; private long _fLong; - public int fInt { get { return _fInt; } set { _fInt = value; } } - public override long fLong { get { return _fLong; } set { _fLong = value; } } + private byte _fByte2; + public int IntProp { get { return _fInt; } set { _fInt = value; } } + public override long LongProp { get { return _fLong; } set { _fLong = value; } } + public override byte ByteProp { get { return _fByte2; } set { _fByte2 = value; } } } [Fact] @@ -685,9 +739,9 @@ public void ClassWithInheritedPropertiesConversion() { var data = new List() { - new ClassWithInheritedProperties(){ fInt=1, fString ="lala", fLong=17 }, - new ClassWithInheritedProperties(){ fInt=-1, fString ="", fLong=2 }, - new ClassWithInheritedProperties(){ fInt=0, fString =null, fLong=18 } + new ClassWithInheritedProperties(){ IntProp=1, StringProp ="lala", LongProp=17, ByteProp=3 }, + new ClassWithInheritedProperties(){ IntProp=-1, StringProp ="", LongProp=2, ByteProp=4 }, + new ClassWithInheritedProperties(){ IntProp=0, StringProp =null, LongProp=18, ByteProp=5 } }; using (var env = new TlcEnvironment()) @@ -814,18 +868,18 @@ public class ClassWithArrayProperties private float[] _fFloat; private double[] _fDouble; private bool[] _fBool; - public string[] fString { get { return _fString; } set { _fString = value; } } - public int[] fInt { get { return _fInt; } set { _fInt = value; } } - public uint[] fuInt { get { return _fuInt; } set { _fuInt = value; } } - public short[] fShort { get { return _fShort; } set { _fShort = value; } } - public ushort[] fuShort { get { return _fuShort; } set { _fuShort = value; } } - public sbyte[] fsByte { get { return _fsByte; } set { _fsByte = value; } } - public byte[] fByte { get { return _fByte; } set { _fByte = value; } } - public long[] fLong { get { return _fLong; } set { _fLong = value; } } - public ulong[] fuLong { get { return _fuLong; } set { _fuLong = value; } } - public float[] fFloat { get { return _fFloat; } set { _fFloat = value; } } - public double[] fDouble { get { return _fDouble; } set { _fDouble = value; } } - public bool[] fBool { get { return _fBool; } set { _fBool = value; } } + public string[] StringProp { get { return _fString; } set { _fString = value; } } + public int[] IntProp { get { return _fInt; } set { _fInt = value; } } + public uint[] UIntProp { get { return _fuInt; } set { _fuInt = value; } } + public short[] ShortProp { get { return _fShort; } set { _fShort = value; } } + public ushort[] UShortProp { get { return _fuShort; } set { _fuShort = value; } } + public sbyte[] SByteProp { get { return _fsByte; } set { _fsByte = value; } } + public byte[] ByteProp { get { return _fByte; } set { _fByte = value; } } + public long[] LongProp { get { return _fLong; } set { _fLong = value; } } + public ulong[] ULongProp { get { return _fuLong; } set { _fuLong = value; } } + public float[] FloatProp { get { return _fFloat; } set { _fFloat = value; } } + public double[] DobuleProp { get { return _fDouble; } set { _fDouble = value; } } + public bool[] BoolProp { get { return _fBool; } set { _fBool = value; } } } public class ClassWithNullableArrayProperties @@ -843,18 +897,18 @@ public class ClassWithNullableArrayProperties private double?[] _fDouble; private bool?[] _fBool; - public string[] fString { get { return _fString; } set { _fString = value; } } - public int?[] fInt { get { return _fInt; } set { _fInt = value; } } - public uint?[] fuInt { get { return _fuInt; } set { _fuInt = value; } } - public short?[] fShort { get { return _fShort; } set { _fShort = value; } } - public ushort?[] fuShort { get { return _fuShort; } set { _fuShort = value; } } - public sbyte?[] fsByte { get { return _fsByte; } set { _fsByte = value; } } - public byte?[] fByte { get { return _fByte; } set { _fByte = value; } } - public long?[] fLong { get { return _fLong; } set { _fLong = value; } } - public ulong?[] fuLong { get { return _fuLong; } set { _fuLong = value; } } - public float?[] fFloat { get { return _fFloat; } set { _fFloat = value; } } - public double?[] fDouble { get { return _fDouble; } set { _fDouble = value; } } - public bool?[] fBool { get { return _fBool; } set { _fBool = value; } } + public string[] StringProp { get { return _fString; } set { _fString = value; } } + public int?[] IntProp { get { return _fInt; } set { _fInt = value; } } + public uint?[] UIntProp { get { return _fuInt; } set { _fuInt = value; } } + public short?[] ShortProp { get { return _fShort; } set { _fShort = value; } } + public ushort?[] UShortProp { get { return _fuShort; } set { _fuShort = value; } } + public sbyte?[] SByteProp { get { return _fsByte; } set { _fsByte = value; } } + public byte?[] ByteProp { get { return _fByte; } set { _fByte = value; } } + public long?[] LongProp { get { return _fLong; } set { _fLong = value; } } + public ulong?[] ULongProp { get { return _fuLong; } set { _fuLong = value; } } + public float?[] SingleProp { get { return _fFloat; } set { _fFloat = value; } } + public double?[] DoubleProp { get { return _fDouble; } set { _fDouble = value; } } + public bool?[] BoolProp { get { return _fBool; } set { _fBool = value; } } } [Fact] @@ -865,20 +919,20 @@ public void RoundTripConversionWithArrayPropertiess() { new ClassWithArrayProperties() { - fInt = new int[3] { 0, 1, 2 }, - fFloat = new float[3] { -0.99f, 0f, 0.99f }, - fString = new string[2] { "hola", "lola" }, - fBool = new bool[2] { true, false }, - fByte = new byte[3] { 0, 124, 255 }, - fDouble = new double[3] { -1, 0, 1 }, - fLong = new long[] { 0, 1, 2 }, - fsByte = new sbyte[3] { -127, 127, 0 }, - fShort = new short[3] { 0, 1225, 32767 }, - fuInt = new uint[2] { 0, uint.MaxValue }, - fuLong = new ulong[2] { ulong.MaxValue, 0 }, - fuShort = new ushort[2] { 0, ushort.MaxValue } + IntProp = new int[3] { 0, 1, 2 }, + FloatProp = new float[3] { -0.99f, 0f, 0.99f }, + StringProp = new string[2] { "hola", "lola" }, + BoolProp = new bool[2] { true, false }, + ByteProp = new byte[3] { 0, 124, 255 }, + DobuleProp = new double[3] { -1, 0, 1 }, + LongProp = new long[] { 0, 1, 2 }, + SByteProp = new sbyte[3] { -127, 127, 0 }, + ShortProp = new short[3] { 0, 1225, 32767 }, + UIntProp = new uint[2] { 0, uint.MaxValue }, + ULongProp = new ulong[2] { ulong.MaxValue, 0 }, + UShortProp = new ushort[2] { 0, ushort.MaxValue } }, - new ClassWithArrayProperties() { fInt = new int[3] { -2, 1, 0 }, fFloat = new float[3] { 0.99f, 0f, -0.99f }, fString = new string[2] { "", null } }, + new ClassWithArrayProperties() { IntProp = new int[3] { -2, 1, 0 }, FloatProp = new float[3] { 0.99f, 0f, -0.99f }, StringProp = new string[2] { "", null } }, new ClassWithArrayProperties() }; @@ -886,20 +940,20 @@ public void RoundTripConversionWithArrayPropertiess() { new ClassWithNullableArrayProperties() { - fInt = new int?[3] { null, -1, 1 }, - fFloat = new float?[3] { -0.99f, null, 0.99f }, - fString = new string[2] { null, "" }, - fBool = new bool?[3] { true, null, false }, - fByte = new byte?[4] { 0, 125, null, 255 }, - fDouble = new double?[3] { -1, null, 1 }, - fLong = new long?[] { null, -1, 1 }, - fsByte = new sbyte?[3] { -127, 127, null }, - fShort = new short?[3] { 0, null, 32767 }, - fuInt = new uint?[4] { null, 42, 0, uint.MaxValue }, - fuLong = new ulong?[3] { ulong.MaxValue, null, 0 }, - fuShort = new ushort?[3] { 0, null, ushort.MaxValue } + IntProp = new int?[3] { null, -1, 1 }, + SingleProp = new float?[3] { -0.99f, null, 0.99f }, + StringProp = new string[2] { null, "" }, + BoolProp = new bool?[3] { true, null, false }, + ByteProp = new byte?[4] { 0, 125, null, 255 }, + DoubleProp = new double?[3] { -1, null, 1 }, + LongProp = new long?[] { null, -1, 1 }, + SByteProp = new sbyte?[3] { -127, 127, null }, + ShortProp = new short?[3] { 0, null, 32767 }, + UIntProp = new uint?[4] { null, 42, 0, uint.MaxValue }, + ULongProp = new ulong?[3] { ulong.MaxValue, null, 0 }, + UShortProp = new ushort?[3] { 0, null, ushort.MaxValue } }, - new ClassWithNullableArrayProperties() { fInt = new int?[3] { -2, 1, 0 }, fFloat = new float?[3] { 0.99f, 0f, -0.99f }, fString = new string[2] { "lola", "hola" } }, + new ClassWithNullableArrayProperties() { IntProp = new int?[3] { -2, 1, 0 }, SingleProp = new float?[3] { 0.99f, 0f, -0.99f }, StringProp = new string[2] { "lola", "hola" } }, new ClassWithNullableArrayProperties() }; From dd4efe48b2836f41ea7fd35da185760562259dc2 Mon Sep 17 00:00:00 2001 From: Don Syme Date: Wed, 1 Aug 2018 19:52:52 +0100 Subject: [PATCH 14/15] add more tests --- test/Microsoft.ML.FSharp.Tests/SmokeTests.fs | 66 ++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs index a677b131a3..570c1e0722 100644 --- a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs +++ b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs @@ -195,3 +195,69 @@ module SmokeTest2 = let predictionResults = [ for p in predictions -> p.Sentiment ] Assert.Equal(predictionResults, [ false; true; true ]) +module SmokeTest3 = + + type SentimentData() = + [] + member val SentimentText = "" with get, set + + [] + member val Sentiment = 0.0 with get, set + + type SentimentPrediction() = + [] + member val Sentiment = false with get, set + + [] + let ``FSharp-Sentiment-Smoke-Test`` () = + + // See https://github.com/dotnet/machinelearning/issues/401: forces the loading of ML.NET component assemblies + let _load = + [ typeof; + typeof ] + + let testDataPath = __SOURCE_DIRECTORY__ + @"/../data/wikipedia-detox-250-line-data.tsv" + + let pipeline = LearningPipeline() + + pipeline.Add( + TextLoader(testDataPath).CreateFrom( + Arguments = + TextLoaderArguments( + HasHeader = true, + Column = [| TextLoaderColumn(Name = "Label", + Source = [| TextLoaderRange(0) |], + Type = Nullable (Data.DataKind.Num)) + TextLoaderColumn(Name = "SentimentText", + Source = [| TextLoaderRange(1) |], + Type = Nullable (Data.DataKind.Text)) |] + ))) + + pipeline.Add( + TextFeaturizer( + "Features", [| "SentimentText" |], + KeepDiacritics = false, + KeepPunctuations = false, + TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, + OutputTokens = true, + VectorNormalizer = TextTransformTextNormKind.L2 + )) + + pipeline.Add( + FastTreeBinaryClassifier( + NumLeaves = 5, + NumTrees = 5, + MinDocumentsInLeafs = 2 + )) + + let model = pipeline.Train() + + let predictions = + [ SentimentData(SentimentText = "This is a gross exaggeration. Nobody is setting a kangaroo court. There was a simple addition.") + SentimentData(SentimentText = "Sort of ok") + SentimentData(SentimentText = "Joe versus the Volcano Coffee Company is a great film.") ] + |> model.Predict + + let predictionResults = [ for p in predictions -> p.Sentiment ] + Assert.Equal(predictionResults, [ false; true; true ]) + From 3d5355904491d6d86bb94c7b97a568ff7ae2e8f1 Mon Sep 17 00:00:00 2001 From: Don Syme Date: Thu, 2 Aug 2018 11:30:24 +0100 Subject: [PATCH 15/15] Update ApiUtils.cs --- src/Microsoft.ML.Api/ApiUtils.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Api/ApiUtils.cs b/src/Microsoft.ML.Api/ApiUtils.cs index 22afda7608..96e821f16e 100644 --- a/src/Microsoft.ML.Api/ApiUtils.cs +++ b/src/Microsoft.ML.Api/ApiUtils.cs @@ -109,7 +109,7 @@ private static Delegate GeneratePeek(PropertyInfo propertyIn il.Emit(OpCodes.Ldarg_3); // push arg3 il.Emit(OpCodes.Ldarg_1); // push arg1 - il.Emit(opcode, minfo); // push [stack top].[propertyInfo] + il.Emit(opcode, minfo); // call [stack top].get_[propertyInfo]() // Stobj needs to coupled with a type. if (assignmentOpCode == OpCodes.Stobj) // [stack top-1] = [stack top] il.Emit(assignmentOpCode, propertyInfo.PropertyType); @@ -181,7 +181,7 @@ private static Delegate GeneratePoke(PropertyInfo propertyIn il.Emit(OpCodes.Ldarg_1); // push arg1 il.Emit(OpCodes.Ldarg_2); // push arg2 - il.Emit(opcode, minfo); // [stack top-1].[propertyInfo] <- [stack top] + il.Emit(opcode, minfo); // call [stack top-1].set_[propertyInfo]([stack top]) il.Emit(OpCodes.Ret); // ret return mb.CreateDelegate(typeof(Poke), null); }