Skip to content

Allow use of property-based row classes in ML.NET #616

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Aug 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions Microsoft.ML.sln
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer",
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer.Tests", "test\Microsoft.ML.CodeAnalyzer.Tests\Microsoft.ML.CodeAnalyzer.Tests.csproj", "{3E4ABF07-7970-4BE6-B45B-A13D3C397545}"
EndProject
Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Microsoft.ML.FSharp.Tests", "test\Microsoft.ML.FSharp.Tests\Microsoft.ML.FSharp.Tests.fsproj", "{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.ImageAnalytics", "src\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj", "{00E38F77-1E61-4CDF-8F97-1417D4E85053}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.HalLearners", "src\Microsoft.ML.HalLearners\Microsoft.ML.HalLearners.csproj", "{A7222F41-1CF0-47D9-B80C-B4D77B027A61}"
Expand Down Expand Up @@ -333,6 +335,14 @@ Global
{3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release|Any CPU.Build.0 = Release|Any CPU
{3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
{3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug|Any CPU.Build.0 = Debug|Any CPU
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release|Any CPU.ActiveCfg = Release|Any CPU
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release|Any CPU.Build.0 = Release|Any CPU
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug|Any CPU.Build.0 = Debug|Any CPU
{00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
Expand Down Expand Up @@ -387,6 +397,7 @@ Global
{BF66A305-DF10-47E4-8D81-42049B149D2B} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
{B4E55B2D-2A92-46E7-B72F-E76D6FD83440} = {7F13E156-3EBA-4021-84A5-CD56BA72F99E}
{3E4ABF07-7970-4BE6-B45B-A13D3C397545} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{00E38F77-1E61-4CDF-8F97-1417D4E85053} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{A7222F41-1CF0-47D9-B80C-B4D77B027A61} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
EndGlobalSection
Expand Down
102 changes: 86 additions & 16 deletions src/Microsoft.ML.Api/ApiUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,31 @@ private static OpCode GetAssignmentOpCode(Type t)
/// </summary>
internal static Delegate GeneratePeek<TOwn, TRow>(InternalSchemaDefinition.Column column)
{
var fieldInfo = column.FieldInfo;
Type fieldType = fieldInfo.FieldType;

var assignmentOpCode = GetAssignmentOpCode(fieldType);
Func<FieldInfo, OpCode, Delegate> func = GeneratePeek<TOwn, TRow, int>;
var methInfo = func.GetMethodInfo().GetGenericMethodDefinition()
.MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType);
return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode });
switch (column.MemberInfo)
{
case FieldInfo fieldInfo:
Type fieldType = fieldInfo.FieldType;

var assignmentOpCode = GetAssignmentOpCode(fieldType);
Func<FieldInfo, OpCode, Delegate> func = GeneratePeek<TOwn, TRow, int>;
var methInfo = func.GetMethodInfo().GetGenericMethodDefinition()
.MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType);
return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode });

case PropertyInfo propertyInfo:
Type propertyType = propertyInfo.PropertyType;

var assignmentOpCodeProp = GetAssignmentOpCode(propertyType);
Func<PropertyInfo, OpCode, Delegate> funcProp = GeneratePeek<TOwn, TRow, int>;
var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition()
.MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType);
return (Delegate)methInfoProp.Invoke(null, new object[] { propertyInfo, assignmentOpCodeProp });

default:
Contracts.Assert(false);
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");

}
}

private static Delegate GeneratePeek<TOwn, TRow, TValue>(FieldInfo fieldInfo, OpCode assignmentOpCode)
Expand All @@ -81,21 +98,59 @@ private static Delegate GeneratePeek<TOwn, TRow, TValue>(FieldInfo fieldInfo, Op
return mb.CreateDelegate(typeof(Peek<TRow, TValue>));
}

private static Delegate GeneratePeek<TOwn, TRow, TValue>(PropertyInfo propertyInfo, OpCode assignmentOpCode)
{
// REVIEW: It seems like we really should cache these, instead of generating them per cursor.
Type[] args = { typeof(TOwn), typeof(TRow), typeof(long), typeof(TValue).MakeByRefType() };
var mb = new DynamicMethod("Peek", null, args, typeof(TOwn), true);
var il = mb.GetILGenerator();
var minfo = propertyInfo.GetGetMethod();
var opcode = (minfo.IsVirtual || minfo.IsAbstract) ? OpCodes.Callvirt : OpCodes.Call;

il.Emit(OpCodes.Ldarg_3); // push arg3
il.Emit(OpCodes.Ldarg_1); // push arg1
il.Emit(opcode, minfo); // call [stack top].get_[propertyInfo]()
// Stobj needs to coupled with a type.
if (assignmentOpCode == OpCodes.Stobj) // [stack top-1] = [stack top]
il.Emit(assignmentOpCode, propertyInfo.PropertyType);
else
il.Emit(assignmentOpCode);
il.Emit(OpCodes.Ret); // ret

return mb.CreateDelegate(typeof(Peek<TRow, TValue>));
}

/// <summary>
/// Each of the specialized 'poke' methods sets the appropriate field value of an instance of T
/// to the provided value. So, the call is 'peek(userObject, providedValue)' and the logic is
/// indentical to 'userObject.##FIELD## = providedValue', where ##FIELD## is defined per poke method.
/// </summary>
internal static Delegate GeneratePoke<TOwn, TRow>(InternalSchemaDefinition.Column column)
{
var fieldInfo = column.FieldInfo;
Type fieldType = fieldInfo.FieldType;

var assignmentOpCode = GetAssignmentOpCode(fieldType);
Func<FieldInfo, OpCode, Delegate> func = GeneratePoke<TOwn, TRow, int>;
var methInfo = func.GetMethodInfo().GetGenericMethodDefinition()
.MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType);
return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode });
switch (column.MemberInfo)
{
case FieldInfo fieldInfo:
Type fieldType = fieldInfo.FieldType;

var assignmentOpCode = GetAssignmentOpCode(fieldType);
Func<FieldInfo, OpCode, Delegate> func = GeneratePoke<TOwn, TRow, int>;
var methInfo = func.GetMethodInfo().GetGenericMethodDefinition()
.MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType);
return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode });

case PropertyInfo propertyInfo:
Type propertyType = propertyInfo.PropertyType;

var assignmentOpCodeProp = GetAssignmentOpCode(propertyType);
Func<PropertyInfo, Delegate> funcProp = GeneratePoke<TOwn, TRow, int>;
var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition()
.MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType);
return (Delegate)methInfoProp.Invoke(null, new object[] { propertyInfo });

default:
Contracts.Assert(false);
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
}
}

private static Delegate GeneratePoke<TOwn, TRow, TValue>(FieldInfo fieldInfo, OpCode assignmentOpCode)
Expand All @@ -115,5 +170,20 @@ private static Delegate GeneratePoke<TOwn, TRow, TValue>(FieldInfo fieldInfo, Op
il.Emit(OpCodes.Ret); // ret
return mb.CreateDelegate(typeof(Poke<TRow, TValue>), null);
}

private static Delegate GeneratePoke<TOwn, TRow, TValue>(PropertyInfo propertyInfo)
{
Type[] args = { typeof(TOwn), typeof(TRow), typeof(TValue) };
var mb = new DynamicMethod("Poke", null, args, typeof(TOwn), true);
var il = mb.GetILGenerator();
var minfo = propertyInfo.GetSetMethod();
var opcode = (minfo.IsVirtual || minfo.IsAbstract) ? OpCodes.Callvirt : OpCodes.Call;

il.Emit(OpCodes.Ldarg_1); // push arg1
il.Emit(OpCodes.Ldarg_2); // push arg2
il.Emit(opcode, minfo); // call [stack top-1].set_[propertyInfo]([stack top])
il.Emit(OpCodes.Ret); // ret
return mb.CreateDelegate(typeof(Poke<TRow, TValue>), null);
}
}
}
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Api/DataViewConstructionUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ private Delegate CreateGetter(int index)
var colType = DataView.Schema.GetColumnType(index);

var column = DataView._schema.SchemaDefn.Columns[index];
var outputType = column.IsComputed ? column.ReturnType : column.FieldInfo.FieldType;
var outputType = column.OutputType;
var genericType = outputType;
Func<int, Delegate> del;

Expand Down
71 changes: 43 additions & 28 deletions src/Microsoft.ML.Api/InternalSchemaDefinition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,23 @@ internal sealed class InternalSchemaDefinition
public class Column
{
public readonly string ColumnName;
public readonly FieldInfo FieldInfo;
public readonly MemberInfo MemberInfo;
public readonly ParameterInfo ReturnParameterInfo;
public readonly ColumnType ColumnType;
public readonly bool IsComputed;
public readonly Delegate Generator;
private readonly Dictionary<string, MetadataInfo> _metadata;
public Dictionary<string, MetadataInfo> Metadata { get { return _metadata; } }
public Type ReturnType {get { return ReturnParameterInfo.ParameterType.GetElementType(); }}
public Type ComputedReturnType {get { return ReturnParameterInfo.ParameterType.GetElementType(); }}
public Type FieldOrPropertyType => (MemberInfo is FieldInfo) ? (MemberInfo as FieldInfo).FieldType : (MemberInfo as PropertyInfo).PropertyType;
public Type OutputType => IsComputed ? ComputedReturnType : FieldOrPropertyType;

public Column(string columnName, ColumnType columnType, FieldInfo fieldInfo) :
this(columnName, columnType, fieldInfo, null, null) { }
public Column(string columnName, ColumnType columnType, MemberInfo memberInfo) :
this(columnName, columnType, memberInfo, null, null) { }

public Column(string columnName, ColumnType columnType, FieldInfo fieldInfo,
public Column(string columnName, ColumnType columnType, MemberInfo memberInfo,
Dictionary<string, MetadataInfo> metadataInfos) :
this(columnName, columnType, fieldInfo, null, metadataInfos) { }
this(columnName, columnType, memberInfo, null, metadataInfos) { }

public Column(string columnName, ColumnType columnType, Delegate generator) :
this(columnName, columnType, null, generator, null) { }
Expand All @@ -46,7 +48,7 @@ public Column(string columnName, ColumnType columnType, Delegate generator,
Dictionary<string, MetadataInfo> metadataInfos) :
this(columnName, columnType, null, generator, metadataInfos) { }

private Column(string columnName, ColumnType columnType, FieldInfo fieldInfo = null,
private Column(string columnName, ColumnType columnType, MemberInfo memberInfo = null,
Delegate generator = null, Dictionary<string, MetadataInfo> metadataInfos = null)
{
Contracts.AssertNonEmpty(columnName);
Expand All @@ -55,8 +57,8 @@ private Column(string columnName, ColumnType columnType, FieldInfo fieldInfo = n

if (generator == null)
{
Contracts.AssertValue(fieldInfo);
FieldInfo = fieldInfo;
Contracts.AssertValue(memberInfo);
MemberInfo = memberInfo;
}
else
{
Expand Down Expand Up @@ -95,8 +97,8 @@ public void AssertRep()
// If Column is computed type, it must have a generator.
Contracts.Assert(IsComputed == (Generator != null));

// Column must have either a generator or a fieldInfo value.
Contracts.Assert((Generator == null) != (FieldInfo == null));
// Column must have either a generator or a memberInfo value.
Contracts.Assert((Generator == null) != (MemberInfo == null));

// Additional Checks if there is a generator.
if (Generator == null)
Expand All @@ -115,9 +117,7 @@ public void AssertRep()
Contracts.Assert(Generator.GetMethodInfo().ReturnType == typeof(void));

// Checks that the return type of the generator is compatible with ColumnType.
bool isVector;
DataKind datakind;
GetVectorAndKind(ReturnType, "return type", out isVector, out datakind);
GetVectorAndKind(ComputedReturnType, "return type", out bool isVector, out DataKind datakind);
Contracts.Assert(isVector == ColumnType.IsVector);
Contracts.Assert(datakind == ColumnType.ItemType.RawKind);
}
Expand All @@ -131,19 +131,30 @@ private InternalSchemaDefinition(Column[] columns)
}

/// <summary>
/// Given a field info on a type, returns whether this appears to be a vector type,
/// Given a field or property info on a type, returns whether this appears to be a vector type,
/// and also the associated data kind for this type. If a data kind could not
/// be determined, this will throw.
/// </summary>
/// <param name="fieldInfo">The field info to inspect.</param>
/// <param name="memberInfo">The field or property info to inspect.</param>
/// <param name="isVector">Whether this appears to be a vector type.</param>
/// <param name="kind">The data kind of the type, or items of this type if vector.</param>
public static void GetVectorAndKind(FieldInfo fieldInfo, out bool isVector, out DataKind kind)
public static void GetVectorAndKind(MemberInfo memberInfo, out bool isVector, out DataKind kind)
{
Contracts.AssertValue(fieldInfo);
Type rawFieldType = fieldInfo.FieldType;
var name = fieldInfo.Name;
GetVectorAndKind(rawFieldType, name, out isVector, out kind);
Contracts.AssertValue(memberInfo);
switch (memberInfo)
{
case FieldInfo fieldInfo:
GetVectorAndKind(fieldInfo.FieldType, fieldInfo.Name, out isVector, out kind);
break;

case PropertyInfo propertyInfo:
GetVectorAndKind(propertyInfo.PropertyType, propertyInfo.Name, out isVector, out kind);
break;

default:
Contracts.Assert(false);
throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
}
}

/// <summary>
Expand Down Expand Up @@ -211,23 +222,27 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us

bool isVector;
DataKind kind;
FieldInfo fieldInfo = null;
MemberInfo memberInfo = null;

if (!col.IsComputed)
{
fieldInfo = userType.GetField(col.MemberName);
memberInfo = userType.GetField(col.MemberName);

if (memberInfo == null)
memberInfo = userType.GetProperty(col.MemberName);

if (fieldInfo == null)
throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No field with name '{0}' found in type '{1}'",
if (memberInfo == null)
throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No field or property with name '{0}' found in type '{1}'",
col.MemberName,
userType.FullName);

//Clause to handle the field that may be used to expose the cursor channel.
//This field does not need a column.
if (fieldInfo.FieldType == typeof(IChannel))
if ( (memberInfo is FieldInfo && (memberInfo as FieldInfo).FieldType == typeof(IChannel)) ||
(memberInfo is PropertyInfo && (memberInfo as PropertyInfo).PropertyType == typeof(IChannel)))
continue;

GetVectorAndKind(fieldInfo, out isVector, out kind);
GetVectorAndKind(memberInfo, out isVector, out kind);
}
else
{
Expand Down Expand Up @@ -268,7 +283,7 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us

dstCols[i] = col.IsComputed ?
new Column(colName, colType, col.Generator, col.Metadata)
: new Column(colName, colType, fieldInfo, col.Metadata);
: new Column(colName, colType, memberInfo, col.Metadata);

}
return new InternalSchemaDefinition(dstCols);
Expand Down
Loading