Skip to content

Commit 429f8cc

Browse files
authored
Fix TextLoader constructor and add exception message (#3788)
1 parent 6722dbf commit 429f8cc

File tree

3 files changed

+102
-2
lines changed

3 files changed

+102
-2
lines changed

src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1456,10 +1456,13 @@ internal static TextLoader CreateTextLoader<TInput>(IHostEnvironment host,
14561456
var propertyInfos =
14571457
userType
14581458
.GetProperties(BindingFlags.Public | BindingFlags.Instance)
1459-
.Where(x => x.CanRead && x.CanWrite && x.GetGetMethod() != null && x.GetSetMethod() != null && x.GetIndexParameters().Length == 0);
1459+
.Where(x => x.CanRead && x.GetGetMethod() != null && x.GetIndexParameters().Length == 0);
14601460

14611461
var memberInfos = (fieldInfos as IEnumerable<MemberInfo>).Concat(propertyInfos).ToArray();
14621462

1463+
if (memberInfos.Length == 0)
1464+
throw host.ExceptParam(nameof(TInput), $"Should define at least one public, readable field or property in {nameof(TInput)}.");
1465+
14631466
var columns = new List<Column>();
14641467

14651468
for (int index = 0; index < memberInfos.Length; index++)

src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,14 @@ public static TextLoader CreateTextLoader(this DataOperationsCatalog catalog,
6262
/// <summary>
6363
/// Create a text loader <see cref="TextLoader"/> by inferencing the dataset schema from a data model type.
6464
/// </summary>
65+
/// <typeparam name="TInput">Defines the schema of the data to be loaded. Use public fields or properties
66+
/// decorated with <see cref="LoadColumnAttribute"/> (and possibly other attributes) to specify the column
67+
/// names and their data types in the schema of the loaded data.</typeparam>
6568
/// <param name="catalog">The <see cref="DataOperationsCatalog"/> catalog.</param>
6669
/// <param name="separatorChar">Column separator character. Default is '\t'</param>
6770
/// <param name="hasHeader">Does the file contains header?</param>
68-
/// <param name="dataSample">The optional location of a data sample. The sample can be used to infer column names and number of slots in each column.</param>
71+
/// <param name="dataSample">The optional location of a data sample. The sample can be used to infer information
72+
/// about the columns, such as slot names.</param>
6973
/// <param name="allowQuoting">Whether the input may include quoted values,
7074
/// which can contain separator characters, colons,
7175
/// and distinguish empty values from missing values. When true, consecutive separators

test/Microsoft.ML.Tests/TextLoaderTests.cs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System;
66
using System.Collections.Generic;
77
using System.IO;
8+
using System.Linq;
89
using Microsoft.ML.Data;
910
using Microsoft.ML.Model;
1011
using Microsoft.ML.RunTests;
@@ -795,5 +796,97 @@ public void TestTextLoaderKeyTypeBackCompat()
795796
Assert.True(result.Schema[featureIdx].Type is KeyDataViewType keyType && keyType.Count == typeof(uint).ToMaxInt());
796797
}
797798
}
799+
800+
private class IrisNoFields
801+
{
802+
}
803+
804+
private class IrisPrivateFields
805+
{
806+
[LoadColumn(0)]
807+
private float SepalLength;
808+
809+
[LoadColumn(1)]
810+
private float SepalWidth { get; }
811+
812+
public float GetSepalLenght()
813+
=> SepalLength;
814+
815+
public void SetSepalLength(float sepalLength)
816+
{
817+
SepalLength = sepalLength;
818+
}
819+
}
820+
private class IrisPublicGetProperties
821+
{
822+
[LoadColumn(0)]
823+
public float SepalLength { get; }
824+
825+
[LoadColumn(1)]
826+
public float SepalWidth { get; }
827+
}
828+
829+
private class IrisPublicFields
830+
{
831+
public IrisPublicFields(float sepalLength, float sepalWidth)
832+
{
833+
SepalLength = sepalLength;
834+
SepalWidth = sepalWidth;
835+
}
836+
837+
[LoadColumn(0)]
838+
public readonly float SepalLength;
839+
840+
[LoadColumn(1)]
841+
public float SepalWidth;
842+
}
843+
844+
private class IrisPublicProperties
845+
{
846+
[LoadColumn(0)]
847+
public float SepalLength { get; set; }
848+
849+
[LoadColumn(1)]
850+
public float SepalWidth { get; set; }
851+
}
852+
853+
[Fact]
854+
public void TestTextLoaderNoFields()
855+
{
856+
var dataPath = GetDataPath(TestDatasets.irisData.trainFilename);
857+
var mlContext = new MLContext();
858+
859+
// Class with get property only.
860+
var dataIris = mlContext.Data.CreateTextLoader<IrisPublicGetProperties>(separatorChar: ',').Load(dataPath);
861+
var oneIrisData = mlContext.Data.CreateEnumerable<IrisPublicProperties>(dataIris, false).First();
862+
Assert.True(oneIrisData.SepalLength != 0 && oneIrisData.SepalWidth != 0);
863+
864+
// Class with read only fields.
865+
dataIris = mlContext.Data.CreateTextLoader<IrisPublicFields>(separatorChar: ',').Load(dataPath);
866+
oneIrisData = mlContext.Data.CreateEnumerable<IrisPublicProperties>(dataIris, false).First();
867+
Assert.True(oneIrisData.SepalLength != 0 && oneIrisData.SepalWidth != 0);
868+
869+
// Class with no fields.
870+
try
871+
{
872+
dataIris = mlContext.Data.CreateTextLoader<IrisNoFields>(separatorChar: ',').Load(dataPath);
873+
Assert.False(true);
874+
}
875+
catch (Exception ex)
876+
{
877+
Assert.StartsWith("Should define at least one public, readable field or property in TInput.", ex.Message);
878+
}
879+
880+
// Class with no public readable fields.
881+
try
882+
{
883+
dataIris = mlContext.Data.CreateTextLoader<IrisPrivateFields>(separatorChar: ',').Load(dataPath);
884+
Assert.False(true);
885+
}
886+
catch (Exception ex)
887+
{
888+
Assert.StartsWith("Should define at least one public, readable field or property in TInput.", ex.Message);
889+
}
890+
}
798891
}
799892
}

0 commit comments

Comments
 (0)