Skip to content

Commit c6eb2f7

Browse files
author
Prashanth Govindarajan
authored
Use dataTypes if it is passed in to LoadCsv (dotnet#2791)
* Fix LoadCsv to use dataTypes if it is passed in * sq * Don't read the full file after guessRows lines have been read * Address feedback * Last round of feedback
1 parent 81f3d42 commit c6eb2f7

File tree

4 files changed

+182
-100
lines changed

4 files changed

+182
-100
lines changed

src/Microsoft.Data.Analysis/DataFrame.IO.cs

+107-99
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
using System;
66
using System.Collections.Generic;
77
using System.IO;
8-
using System.Text;
98

109
namespace Microsoft.Data.Analysis
1110
{
@@ -104,6 +103,77 @@ public static DataFrame LoadCsv(string filename,
104103
}
105104
}
106105

106+
/// <summary>
/// Creates a new column whose element type is <paramref name="kind"/>.
/// </summary>
/// <param name="kind">Element type of the column to create.</param>
/// <param name="columnNames">
/// Column names indexed by column position. When <c>null</c>, the name is
/// "Column" followed by <paramref name="columnIndex"/>.
/// </param>
/// <param name="columnIndex">Zero-based position of the column.</param>
/// <returns>
/// A <see cref="StringDataFrameColumn"/> when <paramref name="kind"/> is <see cref="string"/>;
/// otherwise a <see cref="PrimitiveDataFrameColumn{T}"/> of the matching primitive type.
/// </returns>
/// <exception cref="NotSupportedException">
/// <paramref name="kind"/> is not one of the types handled below.
/// </exception>
private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int columnIndex)
{
    // Resolved lazily so an unsupported type is reported before any lookup into columnNames.
    string GetColumnName()
    {
        return columnNames == null ? "Column" + columnIndex.ToString() : columnNames[columnIndex];
    }

    PrimitiveDataFrameColumn<T> CreatePrimitiveDataFrameColumn<T>()
        where T : unmanaged
    {
        return new PrimitiveDataFrameColumn<T>(GetColumnName());
    }

    if (kind == typeof(bool))
    {
        return CreatePrimitiveDataFrameColumn<bool>();
    }
    if (kind == typeof(int))
    {
        return CreatePrimitiveDataFrameColumn<int>();
    }
    if (kind == typeof(float))
    {
        return CreatePrimitiveDataFrameColumn<float>();
    }
    if (kind == typeof(string))
    {
        return new StringDataFrameColumn(GetColumnName(), 0);
    }
    if (kind == typeof(long))
    {
        return CreatePrimitiveDataFrameColumn<long>();
    }
    if (kind == typeof(decimal))
    {
        return CreatePrimitiveDataFrameColumn<decimal>();
    }
    if (kind == typeof(byte))
    {
        return CreatePrimitiveDataFrameColumn<byte>();
    }
    if (kind == typeof(char))
    {
        return CreatePrimitiveDataFrameColumn<char>();
    }
    if (kind == typeof(double))
    {
        return CreatePrimitiveDataFrameColumn<double>();
    }
    if (kind == typeof(sbyte))
    {
        return CreatePrimitiveDataFrameColumn<sbyte>();
    }
    if (kind == typeof(short))
    {
        return CreatePrimitiveDataFrameColumn<short>();
    }
    if (kind == typeof(uint))
    {
        return CreatePrimitiveDataFrameColumn<uint>();
    }
    if (kind == typeof(ulong))
    {
        return CreatePrimitiveDataFrameColumn<ulong>();
    }
    if (kind == typeof(ushort))
    {
        return CreatePrimitiveDataFrameColumn<ushort>();
    }

    // NotSupportedException(string) takes a message, not a parameter name, so describe
    // the offending type and column instead of just saying "kind".
    string typeName = kind == null ? "null" : kind.ToString();
    throw new NotSupportedException(
        "Cannot create a column of type '" + typeName + "' (" + nameof(kind) + ") for column index " + columnIndex.ToString() + ".");
}
176+
107177
/// <summary>
108178
/// Reads a seekable stream of CSV data into a DataFrame.
109179
/// Follows pandas API.
@@ -116,7 +186,7 @@ public static DataFrame LoadCsv(string filename,
116186
/// <param name="numberOfRowsToRead">number of rows to read not including the header(if present)</param>
117187
/// <param name="guessRows">number of rows used to guess types</param>
118188
/// <param name="addIndexColumn">add one column with the row index</param>
119-
/// <returns>DataFrame</returns>
189+
/// <returns><see cref="DataFrame"/></returns>
120190
public static DataFrame LoadCsv(Stream csvStream,
121191
char separator = ',', bool header = true,
122192
string[] columnNames = null, Type[] dataTypes = null,
@@ -127,7 +197,7 @@ public static DataFrame LoadCsv(Stream csvStream,
127197

128198
var linesForGuessType = new List<string[]>();
129199
long rowline = 0;
130-
int numberOfColumns = 0;
200+
int numberOfColumns = dataTypes?.Length ?? 0;
131201

132202
if (header == true && numberOfRowsToRead != -1)
133203
numberOfRowsToRead++;
@@ -137,60 +207,52 @@ public static DataFrame LoadCsv(Stream csvStream,
137207
// First pass: schema and number of rows.
138208
using (var streamReader = new StreamReader(csvStream, encoding: null, detectEncodingFromByteOrderMarks: true, bufferSize: -1, leaveOpen: true))
139209
{
140-
string line = streamReader.ReadLine();
141-
while (line != null)
210+
string line = null;
211+
if (dataTypes == null)
142212
{
143-
if ((numberOfRowsToRead == -1) || rowline < numberOfRowsToRead)
213+
line = streamReader.ReadLine();
214+
while (line != null)
144215
{
145-
if (linesForGuessType.Count < guessRows)
216+
if ((numberOfRowsToRead == -1) || rowline < numberOfRowsToRead)
146217
{
147-
var spl = line.Split(separator);
148-
if (header && rowline == 0)
218+
if (linesForGuessType.Count < guessRows)
149219
{
150-
if (columnNames == null)
151-
columnNames = spl;
152-
}
153-
else
154-
{
155-
linesForGuessType.Add(spl);
156-
numberOfColumns = Math.Max(numberOfColumns, spl.Length);
220+
var spl = line.Split(separator);
221+
if (header && rowline == 0)
222+
{
223+
if (columnNames == null)
224+
columnNames = spl;
225+
}
226+
else
227+
{
228+
linesForGuessType.Add(spl);
229+
numberOfColumns = Math.Max(numberOfColumns, spl.Length);
230+
}
157231
}
158232
}
233+
++rowline;
234+
if (rowline == guessRows)
235+
{
236+
break;
237+
}
238+
line = streamReader.ReadLine();
159239
}
160-
++rowline;
161-
if (rowline == numberOfRowsToRead)
162-
break;
163-
line = streamReader.ReadLine();
164-
}
165240

166-
if (linesForGuessType.Count == 0)
167-
throw new FormatException(Strings.EmptyFile);
241+
if (linesForGuessType.Count == 0)
242+
{
243+
throw new FormatException(Strings.EmptyFile);
244+
}
245+
}
168246

169247
columns = new List<DataFrameColumn>(numberOfColumns);
170-
171-
// Guesses types and adds columns.
248+
// Guesses types or looks up dataTypes and adds columns.
172249
for (int i = 0; i < numberOfColumns; ++i)
173250
{
174-
Type kind = GuessKind(i, linesForGuessType);
175-
if (kind == typeof(bool))
176-
{
177-
DataFrameColumn boolColumn = new PrimitiveDataFrameColumn<bool>(columnNames == null ? "Column" + i.ToString() : columnNames[i], header == true ? rowline - 1 : rowline);
178-
columns.Add(boolColumn);
179-
}
180-
else if (kind == typeof(float))
181-
{
182-
DataFrameColumn floatColumn = new PrimitiveDataFrameColumn<float>(columnNames == null ? "Column" + i.ToString() : columnNames[i], header == true ? rowline - 1 : rowline);
183-
columns.Add(floatColumn);
184-
}
185-
else if (kind == typeof(string))
186-
{
187-
DataFrameColumn stringColumn = new StringDataFrameColumn(columnNames == null ? "Column" + i.ToString() : columnNames[i], header == true ? rowline - 1 : rowline);
188-
columns.Add(stringColumn);
189-
}
190-
else
191-
throw new NotSupportedException(nameof(kind));
251+
Type kind = dataTypes == null ? GuessKind(i, linesForGuessType) : dataTypes[i];
252+
columns.Add(CreateColumn(kind, columnNames, i));
192253
}
193254

255+
DataFrame ret = new DataFrame(columns);
194256
line = null;
195257
streamReader.DiscardBufferedData();
196258
streamReader.BaseStream.Seek(streamStart, SeekOrigin.Begin);
@@ -207,7 +269,7 @@ public static DataFrame LoadCsv(Stream csvStream,
207269
}
208270
else
209271
{
210-
AppendRow(columns, header == true ? rowline - 1 : rowline, spl);
272+
ret.Append(spl);
211273
}
212274
++rowline;
213275
line = streamReader.ReadLine();
@@ -222,61 +284,7 @@ public static DataFrame LoadCsv(Stream csvStream,
222284
}
223285
columns.Insert(0, indexColumn);
224286
}
225-
}
226-
return new DataFrame(columns);
227-
}
228-
229-
private static void AppendRow(List<DataFrameColumn> columns, long rowIndex, string[] values)
230-
{
231-
for (int i = 0; i < columns.Count; i++)
232-
{
233-
DataFrameColumn column = columns[i];
234-
string val = values[i];
235-
Type dType = column.DataType;
236-
if (dType == typeof(bool))
237-
{
238-
bool boolParse = bool.TryParse(val, out bool boolResult);
239-
if (boolParse)
240-
{
241-
column[rowIndex] = boolResult;
242-
continue;
243-
}
244-
else
245-
{
246-
if (string.IsNullOrEmpty(val))
247-
{
248-
column[rowIndex] = null;
249-
continue;
250-
}
251-
throw new ArgumentException(string.Format(Strings.MismatchedValueType, typeof(bool)), nameof(val));
252-
}
253-
}
254-
else if (dType == typeof(float))
255-
{
256-
bool floatParse = float.TryParse(val, out float floatResult);
257-
if (floatParse)
258-
{
259-
column[rowIndex] = floatResult;
260-
continue;
261-
}
262-
else
263-
{
264-
if (string.IsNullOrEmpty(val))
265-
{
266-
column[rowIndex] = null;
267-
continue;
268-
}
269-
throw new ArgumentException(string.Format(Strings.MismatchedValueType, typeof(float)), nameof(val));
270-
}
271-
}
272-
else if (dType == typeof(string))
273-
{
274-
column[rowIndex] = values[i];
275-
}
276-
else
277-
{
278-
throw new NotImplementedException();
279-
}
287+
return ret;
280288
}
281289
}
282290
}

src/Microsoft.Data.Analysis/DataFrame.cs

+5
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,11 @@ public void Append(IEnumerable<object> row = null)
445445
{
446446
DataFrameColumn column = columnEnumerator.Current;
447447
object value = rowEnumerator.Current;
448+
// StringDataFrameColumn can accept empty strings. The other columns interpret empty values as nulls
449+
if (value is string stringValue && string.IsNullOrEmpty(stringValue) && column.DataType != typeof(string))
450+
{
451+
value = null;
452+
}
448453
if (value != null)
449454
{
450455
value = Convert.ChangeType(value, column.DataType);

tests/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs

+47-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
using System;
66
using System.IO;
7-
using System.Runtime.CompilerServices;
87
using System.Text;
98
using Xunit;
109

@@ -58,5 +57,52 @@ Stream GetStream(string streamData)
5857
Assert.Equal(7, reducedRows.Columns.Count);
5958
Assert.Equal("CMT", reducedRows["Column0"][2]);
6059
}
60+
61+
// Loads a small CSV with explicit dataTypes and checks the row/column counts, that each
// column reports the requested type, and that the all-empty row (index 3) is read as null
// in non-string columns and as "" in string columns.
[Fact]
public void TestReadCsvWithTypes()
{
    string csvText = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount
CMT,1,1,1271,3.8,CRD,17.5
CMT,1,1,474,1.5,CRD,8
CMT,1,1,637,1.4,CRD,8.5
,,,,,,
CMT,1,1,181,0.6,CSH,4.5";

    // One requested type per CSV column, in column order.
    Type[] requestedTypes = new Type[]
    {
        typeof(string), typeof(short), typeof(int), typeof(long), typeof(float), typeof(string), typeof(double)
    };

    Stream csvStream = new MemoryStream(Encoding.Default.GetBytes(csvText));
    DataFrame frame = DataFrame.LoadCsv(csvStream, dataTypes: requestedTypes);

    Assert.Equal(5, frame.RowCount);
    Assert.Equal(7, frame.Columns.Count);

    // Each column must carry exactly the type that was requested for it.
    for (int i = 0; i < requestedTypes.Length; i++)
    {
        Assert.True(requestedTypes[i] == frame.Columns[i].DataType);
    }

    // The single all-empty row contributes one null to each non-string column and none to string columns.
    foreach (var column in frame.Columns)
    {
        long expectedNulls = column.DataType == typeof(string) ? 0 : 1;
        Assert.Equal(expectedNulls, column.NullCount);
    }

    // Row 3 is the ",,,,,," line.
    var emptyRow = frame[3];
    for (int i = 0; i < requestedTypes.Length; i++)
    {
        if (requestedTypes[i] == typeof(string))
        {
            Assert.Equal("", emptyRow[i]);
        }
        else
        {
            Assert.Null(emptyRow[i]);
        }
    }
}
61107
}
62108
}

tests/Microsoft.Data.Analysis.Tests/DataFrameTests.cs

+23
Original file line numberDiff line numberDiff line change
@@ -1926,5 +1926,28 @@ public void TestAppendRow()
19261926
Assert.Equal(5, df.Columns[0].NullCount);
19271927
Assert.Equal(6, df.Columns[1].NullCount);
19281928
}
1929+
1930+
// Checks how DataFrame.Append treats empty strings: an empty string aimed at a non-string
// column is counted as a null, while a string column stores it as a value and only a real
// null raises its NullCount.
[Fact]
public void TestAppendEmptyValue()
{
    DataFrame frame = MakeDataFrame<int, bool>(10);

    // "" lands in the int column (index 0); its NullCount is expected to reach 2 afterwards.
    // NOTE(review): presumably MakeDataFrame seeds one null per column — confirm against the helper.
    var rowWithEmptyInt = new List<object> { "", true };
    frame.Append(rowWithEmptyInt);
    Assert.Equal(11, frame.RowCount);
    Assert.Equal(2, frame.Columns[0].NullCount);
    Assert.Equal(1, frame.Columns[1].NullCount);

    // Add a fully populated string column ("0".."10") so the next appends cover the string case.
    IEnumerable<string> labels = Enumerable.Range(0, 11).Select(x => x.ToString());
    StringDataFrameColumn strings = new StringDataFrameColumn("Strings", labels);
    frame.Columns.Add(strings);

    // "" in the string column is kept as a value, so no NullCount changes.
    var rowWithEmptyString = new List<object> { 1, true, "" };
    frame.Append(rowWithEmptyString);
    Assert.Equal(12, frame.RowCount);
    Assert.Equal(2, frame.Columns[0].NullCount);
    Assert.Equal(1, frame.Columns[1].NullCount);
    Assert.Equal(0, frame.Columns[2].NullCount);

    // An actual null in the string column is counted.
    var rowWithNullString = new List<object> { 1, true, null };
    frame.Append(rowWithNullString);
    Assert.Equal(13, frame.RowCount);
    Assert.Equal(1, frame.Columns[2].NullCount);
}
19291952
}
19301953
}

0 commit comments

Comments
 (0)