diff --git a/src/Microsoft.ML.Parquet/ParquetLoader.cs b/src/Microsoft.ML.Parquet/ParquetLoader.cs
index 21271f6e5c..2def7006d2 100644
--- a/src/Microsoft.ML.Parquet/ParquetLoader.cs
+++ b/src/Microsoft.ML.Parquet/ParquetLoader.cs
@@ -12,6 +12,7 @@
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Data.IO;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
using Parquet;
@@ -88,48 +89,29 @@ public sealed class Arguments
internal const string ShortName = "Parquet";
internal const string ModelSignature = "PARQELDR";
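+ // Name of the sub-stream within the model context that stores the loader's schema as a binary IDV.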
+ private const string SchemaCtxName = "Schema.idv";
+
private readonly IHost _host;
private readonly Stream _parquetStream;
private readonly ParquetOptions _parquetOptions;
private readonly int _columnChunkReadSize;
private readonly Column[] _columnsLoaded;
- private readonly DataSet _schemaDataSet;
private const int _defaultColumnChunkReadSize = 1000000;
private bool _disposed;
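+ // Total row count, captured when the schema is first read from the file. Null when the
+ // loader is restored from a model context without an accompanying Parquet file.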
+ private long? _rowCount;
private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: ModelSignature,
- verWrittenCur: 0x00010001, // Initial
- verReadableCur: 0x00010001,
+ //verWrittenCur: 0x00010001, // Initial
+ verWrittenCur: 0x00010002, // Add Schema to Model Context
+ verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature);
}
- public static ParquetLoader Create(IHostEnvironment env, ModelLoadContext ctx, IMultiStreamSource files)
- {
- Contracts.CheckValue(env, nameof(env));
- IHost host = env.Register(LoaderName);
-
- env.CheckValue(ctx, nameof(ctx));
- ctx.CheckAtModel(GetVersionInfo());
- env.CheckValue(files, nameof(files));
-
- // *** Binary format ***
- // int: cached chunk size
- // bool: TreatBigIntegersAsDates flag
-
- Arguments args = new Arguments
- {
- ColumnChunkReadSize = ctx.Reader.ReadInt32(),
- TreatBigIntegersAsDates = ctx.Reader.ReadBoolean()
- };
- return host.Apply("Loading Model",
- ch => new ParquetLoader(args, host, OpenStream(files)));
- }
-
public ParquetLoader(IHostEnvironment env, Arguments args, IMultiStreamSource files)
: this(env, args, OpenStream(files))
{
@@ -165,6 +147,8 @@ private ParquetLoader(Arguments args, IHost host, Stream stream)
TreatBigIntegersAsDates = args.TreatBigIntegersAsDates
};
+ DataSet schemaDataSet;
+
try
{
// We only care about the schema so ignore the rows.
@@ -173,7 +157,8 @@ private ParquetLoader(Arguments args, IHost host, Stream stream)
Count = 0,
Offset = 0
};
- _schemaDataSet = ParquetReader.Read(stream, _parquetOptions, readerOptions);
+ schemaDataSet = ParquetReader.Read(stream, _parquetOptions, readerOptions);
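+ // Cache the row count up front; the DataSet itself is only needed locally to build the column schema.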
+ _rowCount = schemaDataSet.TotalRowCount;
}
catch (Exception ex)
{
@@ -181,28 +166,103 @@ private ParquetLoader(Arguments args, IHost host, Stream stream)
}
_columnChunkReadSize = args.ColumnChunkReadSize;
- InitColumns(ch, out _columnsLoaded);
+ _columnsLoaded = InitColumns(schemaDataSet);
Schema = CreateSchema(_host, _columnsLoaded);
}
}
+ private ParquetLoader(IHost host, ModelLoadContext ctx, IMultiStreamSource files)
+ {
+ Contracts.AssertValue(host);
+ _host = host;
+ _host.AssertValue(ctx);
+ _host.AssertValue(files);
+
+ // *** Binary format ***
+ // int: cached chunk size
+ // bool: TreatBigIntegersAsDates flag
+ // Schema of the loader (0x00010002)
+
+ _columnChunkReadSize = ctx.Reader.ReadInt32();
+ bool treatBigIntegersAsDates = ctx.Reader.ReadBoolean();
+
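+ // Models written before version 0x00010002 did not persist the schema, so for those the
+ // schema can only be reconstructed from the Parquet file itself (below).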
+ if (ctx.Header.ModelVerWritten >= 0x00010002)
+ {
+ // Load the schema
+ byte[] buffer = null;
+ if (!ctx.TryLoadBinaryStream(SchemaCtxName, r => buffer = r.ReadByteArray()))
+ throw _host.ExceptDecode();
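+ // The schema was serialized as an in-memory binary IDV stream; rehydrate it through a BinaryLoader.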
+ var strm = new MemoryStream(buffer, writable: false);
+ var loader = new BinaryLoader(_host, new BinaryLoader.Arguments(), strm);
+ Schema = loader.Schema;
+ }
+
+ // Only load Parquet-related data if a file is present. Otherwise, just the Schema is valid.
+ if (files.Count > 0)
+ {
+ _parquetOptions = new ParquetOptions()
+ {
+ TreatByteArrayAsString = true,
+ TreatBigIntegersAsDates = treatBigIntegersAsDates
+ };
+
+ _parquetStream = OpenStream(files);
+ DataSet schemaDataSet;
+
+ try
+ {
+ // We only care about the schema so ignore the rows.
+ ReaderOptions readerOptions = new ReaderOptions()
+ {
+ Count = 0,
+ Offset = 0
+ };
+ schemaDataSet = ParquetReader.Read(_parquetStream, _parquetOptions, readerOptions);
+ _rowCount = schemaDataSet.TotalRowCount;
+ }
+ catch (Exception ex)
+ {
+ throw new InvalidDataException("Cannot read Parquet file", ex);
+ }
+
+ _columnsLoaded = InitColumns(schemaDataSet);
+ Schema = CreateSchema(_host, _columnsLoaded);
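+ // Note that when a file is supplied, the schema recomputed from it takes precedence
+ // over any schema deserialized from the model context above.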
+ }
+ else if (Schema == null)
+ {
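+ // Neither a serialized schema nor a Parquet file is available, so there is nothing to load.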
+ throw _host.Except("Parquet loader must be created with at least one file");
+ }
+ }
+
+ public static ParquetLoader Create(IHostEnvironment env, ModelLoadContext ctx, IMultiStreamSource files)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ IHost host = env.Register(LoaderName);
+
+ env.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel(GetVersionInfo());
+ env.CheckValue(files, nameof(files));
+
+ return host.Apply("Loading Model",
+ ch => new ParquetLoader(host, ctx, files));
+ }
+
/// <summary>
/// Helper function called by the ParquetLoader constructor to initialize the Columns that belong in the Parquet file.
/// Composite data fields are flattened; for example, a Map Field in Parquet is flattened into a Key column and a Value
/// column.
/// </summary>
- /// <param name="ch">Communication channel for error reporting.</param>
- /// <param name="cols">The array of flattened columns instantiated from the parquet file.</param>
- private void InitColumns(IChannel ch, out Column[] cols)
+ /// <param name="dataSet">The schema data set.</param>
+ /// <returns>The array of flattened columns instantiated from the parquet file.</returns>
+ private Column[] InitColumns(DataSet dataSet)
{
- cols = null;
List<Column> columnsLoaded = new List<Column>();
- foreach (var parquetField in _schemaDataSet.Schema.Fields)
+ foreach (var parquetField in dataSet.Schema.Fields)
{
FlattenFields(parquetField, ref columnsLoaded, false);
}
- cols = columnsLoaded.ToArray();
+ return columnsLoaded.ToArray();
}
private void FlattenFields(Field field, ref List<Column> cols, bool isRepeatable)
@@ -239,7 +299,7 @@ private void FlattenFields(Field field, ref List<Column> cols, bool isRepeatable)
}
else
{
- throw new InvalidDataException("Encountered unknown Parquet field type(Currently recognizes data, map, list, and struct).");
+ throw _host.ExceptNotSupp("Encountered unknown Parquet field type (currently recognizes data, map, list, and struct).");
}
}
@@ -326,7 +386,7 @@ private static Stream OpenStream(string filename)
public long? GetRowCount(bool lazy = true)
{
- return _schemaDataSet.TotalRowCount;
+ return _rowCount;
}
public IRowCursor GetRowCursor(Func<int, bool> predicate, IRandom rand = null)
@@ -353,9 +413,22 @@ public void Save(ModelSaveContext ctx)
// *** Binary format ***
// int: cached chunk size
// bool: TreatBigIntegersAsDates flag
+ // Schema of the loader
ctx.Writer.Write(_columnChunkReadSize);
ctx.Writer.Write(_parquetOptions.TreatBigIntegersAsDates);
+
+ // Save the schema
+ var noRows = new EmptyDataView(_host, Schema);
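+ // An empty view carries the schema without any rows, so only the metadata is serialized.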
+ var saverArgs = new BinarySaver.Arguments();
+ saverArgs.Silent = true;
+ var saver = new BinarySaver(_host, saverArgs);
+ using (var strm = new MemoryStream())
+ {
+ var allColumns = Enumerable.Range(0, Schema.ColumnCount).ToArray();
+ saver.SaveData(strm, noRows, allColumns);
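+ // Persist the serialized schema bytes as a named sub-stream of the model context.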
+ ctx.SaveBinaryStream(SchemaCtxName, w => w.WriteByteArray(strm.ToArray()));
+ }
}
private sealed class Cursor : RootCursorBase, IRowCursor
@@ -377,6 +450,8 @@ public Cursor(ParquetLoader parent, Func<int, bool> predicate, IRandom rand)
: base(parent._host)
{
Ch.AssertValue(predicate);
+ Ch.AssertValue(parent._parquetStream);
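+ // Cursoring requires an actual Parquet stream; a loader restored from a model context
+ // without files exposes only its schema and cannot produce rows.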
+
_loader = parent;
_fileStream = parent._parquetStream;
_parquetConversions = new ParquetConversions(Ch);