Skip to content

Commit e34f2e1

Browse files
tannergoodingEric Erhardt
authored and
Eric Erhardt
committed
Updating DatabaseLoader to support getting column info from a given .NET type. (dotnet#4091)
* Updating DatabaseLoader to support getting column info from a given .NET type. * Addressing PR feedback. * Reverting DatabaseLoaderTests.IrisLightGbm to manually creating the columns. * Fixing a naming error: source to sources
1 parent e389752 commit e34f2e1

File tree

6 files changed

+269
-48
lines changed

6 files changed

+269
-48
lines changed

src/Microsoft.ML.Data/Data/SchemaDefinition.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ public sealed class ColumnNameAttribute : Attribute
104104
/// <summary>
105105
/// Column name.
106106
/// </summary>
107+
[BestFriend]
107108
internal string Name { get; }
108109

109110
/// <summary>

src/Microsoft.ML.Data/DataLoadSave/Text/LoadColumnAttribute.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ public LoadColumnAttribute(int[] columnIndexes)
4646
Sources.Add(new TextLoader.Range(col));
4747
}
4848

49+
[BestFriend]
4950
internal List<TextLoader.Range> Sources;
5051
}
5152
}

src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoader.cs

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
using System;
66
using System.Collections.Generic;
77
using System.Data;
8+
using System.Data.Common;
89
using System.Linq;
10+
using System.Reflection;
11+
using System.Runtime.CompilerServices;
12+
using System.Text;
913
using Microsoft.ML;
1014
using Microsoft.ML.CommandLine;
1115
using Microsoft.ML.Data;
@@ -98,6 +102,71 @@ void ICanSaveModel.Save(ModelSaveContext ctx)
98102
/// <param name="source">The source from which to load data.</param>
99103
public IDataView Load(DatabaseSource source) => new BoundLoader(this, source);
100104

105+
/// <summary>
/// Creates a <see cref="DatabaseLoader"/> whose columns are inferred from the members of
/// <typeparamref name="TInput"/>.
/// </summary>
/// <remarks>
/// Every public instance field and every readable, non-indexer public instance property of
/// <typeparamref name="TInput"/> produces one <see cref="Column"/>:
/// <list type="bullet">
/// <item><description>The column name comes from <see cref="ColumnNameAttribute"/> when present, otherwise from the member name.</description></item>
/// <item><description>The column source comes from <see cref="LoadColumnAttribute"/> when present; the attribute must describe exactly one source.</description></item>
/// <item><description>The column type is derived from the member type (or its element type for arrays).</description></item>
/// </list>
/// </remarks>
/// <typeparam name="TInput">The type whose public fields and properties describe the schema.</typeparam>
/// <param name="host">The environment used to report errors and to construct the loader.</param>
/// <returns>A <see cref="DatabaseLoader"/> configured with one column per discovered member.</returns>
internal static DatabaseLoader CreateDatabaseLoader<TInput>(IHostEnvironment host)
{
    var userType = typeof(TInput);

    var fieldInfos = userType.GetFields(BindingFlags.Public | BindingFlags.Instance);

    var propertyInfos =
        userType
            .GetProperties(BindingFlags.Public | BindingFlags.Instance)
            .Where(x => x.CanRead && x.GetGetMethod() != null && x.GetIndexParameters().Length == 0);

    var memberInfos = (fieldInfos as IEnumerable<MemberInfo>).Concat(propertyInfos).ToArray();

    // nameof(TInput) only ever yields the literal "TInput", so report the actual type name instead.
    if (memberInfos.Length == 0)
        throw host.ExceptParam(nameof(TInput), $"Should define at least one public, readable field or property in {userType.Name}.");

    var columns = new List<Column>(memberInfos.Length);

    foreach (var memberInfo in memberInfos)
    {
        var column = new Column();
        column.Name = memberInfo.GetCustomAttribute<ColumnNameAttribute>()?.Name ?? memberInfo.Name;

        var mappingAttr = memberInfo.GetCustomAttribute<LoadColumnAttribute>();

        if (mappingAttr != null)
        {
            var sources = mappingAttr.Sources.Select(source => Range.FromTextLoaderRange(source)).ToArray();

            // A column maps to a single source index here. Fail with a message that names the member
            // rather than letting Single() raise an anonymous "Sequence contains ..." error.
            if (sources.Length != 1)
                throw host.Except($"Member {memberInfo.Name} must specify exactly one source in its {nameof(LoadColumnAttribute)}, but {sources.Length} were specified.");

            column.Source = sources[0].Min;
        }

        // Resolve the declared type of the member once so the kind lookup is not duplicated per member kind.
        Type memberType;
        string memberKind;
        switch (memberInfo)
        {
            case FieldInfo field:
                memberType = field.FieldType;
                memberKind = "Field";
                break;

            case PropertyInfo property:
                memberType = property.PropertyType;
                memberKind = "Property";
                break;

            default:
                // memberInfos is built exclusively from GetFields/GetProperties, so this is unreachable.
                Contracts.Assert(false);
                throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
        }

        // For array members the element type determines the data kind.
        var itemType = memberType.IsArray ? memberType.GetElementType() : memberType;

        if (!InternalDataKindExtensions.TryGetDataKind(itemType, out InternalDataKind dk))
            throw host.Except($"{memberKind} {memberInfo.Name} is of unsupported type.");

        column.Type = dk.ToDbType();

        columns.Add(column);
    }

    var options = new Options
    {
        Columns = columns.ToArray()
    };
    return new DatabaseLoader(host, options);
}
169+
101170
/// <summary>
102171
/// Describes how an input column should be mapped to an <see cref="IDataView"/> column.
103172
/// </summary>
@@ -128,6 +197,86 @@ public sealed class Column
128197
public KeyCount KeyCount;
129198
}
130199

200+
/// <summary>
201+
/// Specifies the range of indices of input columns that should be mapped to an output column.
202+
/// </summary>
203+
public sealed class Range
{
    /// <summary>
    /// Creates a range with all members left at their default values.
    /// </summary>
    public Range() { }

    /// <summary>
    /// A range representing a single value. Will result in a scalar column.
    /// </summary>
    /// <param name="index">The index of the input column to read.</param>
    public Range(int index)
    {
        Contracts.CheckParam(index >= 0, nameof(index), "Must be non-negative");
        Min = index;
        Max = index;
    }

    /// <summary>
    /// A range representing a set of values. Will result in a vector column.
    /// </summary>
    /// <param name="min">The minimum inclusive index of the column.</param>
    /// <param name="max">The maximum-inclusive index of the column. If <c>null</c>
    /// indicates that the <see cref="TextLoader"/> should auto-detect the length
    /// of the lines, and read until the end.</param>
    public Range(int min, int? max)
    {
        Contracts.CheckParam(min >= 0, nameof(min), "Must be non-negative");
        // When max is null the lifted comparison is false, so an unspecified max always passes.
        Contracts.CheckParam(!(max < min), nameof(max), "If specified, must be greater than or equal to " + nameof(min));

        Min = min;
        Max = max;
        // Note that without the following being set, in the case where there is a single range
        // where Min == Max, the result will not be a vector valued but a scalar column.
        ForceVector = true;
        AutoEnd = max == null;
    }

    /// <summary>
    /// The minimum index of the column, inclusive.
    /// </summary>
    [Argument(ArgumentType.Required, HelpText = "First index in the range")]
    public int Min;

    /// <summary>
    /// The maximum index of the column, inclusive. If <see langword="null"/>
    /// indicates that the <see cref="TextLoader"/> should auto-detect the length
    /// of the lines, and read until the end.
    /// If <see cref="Max"/> is specified, the field <see cref="AutoEnd"/> is ignored.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Last index in the range")]
    public int? Max;

    /// <summary>
    /// Whether this range extends to the end of the line, but should be a fixed number of items.
    /// If <see cref="Max"/> is specified, the field <see cref="AutoEnd"/> is ignored.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce,
        HelpText = "This range extends to the end of the line, but should be a fixed number of items",
        ShortName = "auto")]
    public bool AutoEnd;

    /// <summary>
    /// Whether this range includes only other indices not specified.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "This range includes only other indices not specified", ShortName = "other")]
    public bool AllOther;

    /// <summary>
    /// Force scalar columns to be treated as vectors of length one.
    /// </summary>
    [Argument(ArgumentType.AtMostOnce, HelpText = "Force scalar columns to be treated as vectors of length one", ShortName = "vector")]
    public bool ForceVector;

    /// <summary>
    /// Converts a <see cref="TextLoader.Range"/> into the equivalent <see cref="Range"/>.
    /// </summary>
    /// <param name="range">The text loader range to convert.</param>
    /// <returns>A new <see cref="Range"/> with the same bounds and flags as <paramref name="range"/>.</returns>
    internal static Range FromTextLoaderRange(TextLoader.Range range)
    {
        // The (min, max) constructor validates the bounds, but it also unconditionally sets
        // ForceVector and recomputes AutoEnd. Copy the flags from the source afterwards so that,
        // for example, a scalar range is not silently turned into a forced vector.
        var converted = new Range(range.Min, range.Max);
        converted.AutoEnd = range.AutoEnd;
        converted.AllOther = range.AllOther;
        converted.ForceVector = range.ForceVector;
        return converted;
    }
}
279+
131280
/// <summary>
132281
/// The settings for <see cref="DatabaseLoader"/>
133282
/// </summary>

src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoaderCatalog.cs

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@ namespace Microsoft.ML
1111
/// </summary>
1212
public static class DatabaseLoaderCatalog
1313
{
14-
/// <summary>
15-
/// Create a database loader <see cref="DatabaseLoader"/>.
16-
/// </summary>
14+
/// <summary>Create a database loader <see cref="DatabaseLoader"/>.</summary>
1715
/// <param name="catalog">The <see cref="DataOperationsCatalog"/> catalog.</param>
1816
/// <param name="columns">Array of columns <see cref="DatabaseLoader.Column"/> defining the schema.</param>
1917
public static DatabaseLoader CreateDatabaseLoader(this DataOperationsCatalog catalog,
@@ -23,8 +21,22 @@ public static DatabaseLoader CreateDatabaseLoader(this DataOperationsCatalog cat
2321
{
2422
Columns = columns,
2523
};
26-
27-
return new DatabaseLoader(CatalogUtils.GetEnvironment(catalog), options);
24+
return catalog.CreateDatabaseLoader(options);
2825
}
26+
27+
/// <summary>
/// Creates a <see cref="DatabaseLoader"/> configured by the supplied <paramref name="options"/>.
/// </summary>
/// <param name="catalog">The <see cref="DataOperationsCatalog"/> catalog.</param>
/// <param name="options">Defines the settings of the load operation.</param>
/// <returns>A new <see cref="DatabaseLoader"/> built from the catalog's environment and <paramref name="options"/>.</returns>
public static DatabaseLoader CreateDatabaseLoader(this DataOperationsCatalog catalog,
    DatabaseLoader.Options options)
{
    var environment = CatalogUtils.GetEnvironment(catalog);
    return new DatabaseLoader(environment, options);
}
33+
34+
/// <summary>
/// Creates a <see cref="DatabaseLoader"/> whose schema is described by <typeparamref name="TInput"/>.
/// </summary>
/// <typeparam name="TInput">Defines the schema of the data to be loaded. Use public fields or properties
/// decorated with <see cref="LoadColumnAttribute"/> (and possibly other attributes) to specify the column
/// names and their data types in the schema of the loaded data.</typeparam>
/// <param name="catalog">The <see cref="DataOperationsCatalog"/> catalog.</param>
/// <returns>The <see cref="DatabaseLoader"/> produced for <typeparamref name="TInput"/> using the catalog's environment.</returns>
public static DatabaseLoader CreateDatabaseLoader<TInput>(this DataOperationsCatalog catalog)
{
    var environment = CatalogUtils.GetEnvironment(catalog);
    return DatabaseLoader.CreateDatabaseLoader<TInput>(environment);
}
2941
}
3042
}

src/Microsoft.ML.Experimental/DataLoadSave/Database/DbExtensions.cs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,82 @@ public static Type ToType(this DbType dbType)
6666
return null;
6767
}
6868
}
69+
70+
/// <summary>
/// Converts an <see cref="InternalDataKind"/> into the <see cref="DbType"/> that represents it.
/// </summary>
/// <param name="dataKind">The data kind to convert.</param>
/// <returns>The <see cref="DbType"/> paired with <paramref name="dataKind"/>.</returns>
/// <exception cref="NotSupportedException"><paramref name="dataKind"/> is not one of the kinds listed in the switch.</exception>
public static DbType ToDbType(this InternalDataKind dataKind)
{
    switch (dataKind)
    {
        // Signed integers.
        case InternalDataKind.I1: return DbType.SByte;
        case InternalDataKind.I2: return DbType.Int16;
        case InternalDataKind.I4: return DbType.Int32;
        case InternalDataKind.I8: return DbType.Int64;

        // Unsigned integers.
        case InternalDataKind.U1: return DbType.Byte;
        case InternalDataKind.U2: return DbType.UInt16;
        case InternalDataKind.U4: return DbType.UInt32;
        case InternalDataKind.U8: return DbType.UInt64;

        // Floating point.
        case InternalDataKind.R4: return DbType.Single;
        case InternalDataKind.R8: return DbType.Double;

        // Text, boolean and date/time.
        case InternalDataKind.TX: return DbType.String;
        case InternalDataKind.BL: return DbType.Boolean;
        case InternalDataKind.DT: return DbType.DateTime;

        default:
            throw new NotSupportedException();
    }
}
69146
}
70147
}

0 commit comments

Comments
 (0)