diff --git a/src/Microsoft.ML.Parquet/ParquetLoader.cs b/src/Microsoft.ML.Parquet/ParquetLoader.cs index 21271f6e5c..2def7006d2 100644 --- a/src/Microsoft.ML.Parquet/ParquetLoader.cs +++ b/src/Microsoft.ML.Parquet/ParquetLoader.cs @@ -12,6 +12,7 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; using Parquet; @@ -88,48 +89,29 @@ public sealed class Arguments internal const string ShortName = "Parquet"; internal const string ModelSignature = "PARQELDR"; + private const string SchemaCtxName = "Schema.idv"; + private readonly IHost _host; private readonly Stream _parquetStream; private readonly ParquetOptions _parquetOptions; private readonly int _columnChunkReadSize; private readonly Column[] _columnsLoaded; - private readonly DataSet _schemaDataSet; private const int _defaultColumnChunkReadSize = 1000000; private bool _disposed; + private long? 
_rowCount; private static VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: ModelSignature, - verWrittenCur: 0x00010001, // Initial - verReadableCur: 0x00010001, + //verWrittenCur: 0x00010001, // Initial + verWrittenCur: 0x00010002, // Add Schema to Model Context + verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature); } - public static ParquetLoader Create(IHostEnvironment env, ModelLoadContext ctx, IMultiStreamSource files) - { - Contracts.CheckValue(env, nameof(env)); - IHost host = env.Register(LoaderName); - - env.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(GetVersionInfo()); - env.CheckValue(files, nameof(files)); - - // *** Binary format *** - // int: cached chunk size - // bool: TreatBigIntegersAsDates flag - - Arguments args = new Arguments - { - ColumnChunkReadSize = ctx.Reader.ReadInt32(), - TreatBigIntegersAsDates = ctx.Reader.ReadBoolean() - }; - return host.Apply("Loading Model", - ch => new ParquetLoader(args, host, OpenStream(files))); - } - public ParquetLoader(IHostEnvironment env, Arguments args, IMultiStreamSource files) : this(env, args, OpenStream(files)) { @@ -165,6 +147,8 @@ private ParquetLoader(Arguments args, IHost host, Stream stream) TreatBigIntegersAsDates = args.TreatBigIntegersAsDates }; + DataSet schemaDataSet; + try { // We only care about the schema so ignore the rows. 
@@ -173,7 +157,8 @@ private ParquetLoader(Arguments args, IHost host, Stream stream) Count = 0, Offset = 0 }; - _schemaDataSet = ParquetReader.Read(stream, _parquetOptions, readerOptions); + schemaDataSet = ParquetReader.Read(stream, _parquetOptions, readerOptions); + _rowCount = schemaDataSet.TotalRowCount; } catch (Exception ex) { @@ -181,28 +166,103 @@ private ParquetLoader(Arguments args, IHost host, Stream stream) } _columnChunkReadSize = args.ColumnChunkReadSize; - InitColumns(ch, out _columnsLoaded); + _columnsLoaded = InitColumns(schemaDataSet); Schema = CreateSchema(_host, _columnsLoaded); } } + private ParquetLoader(IHost host, ModelLoadContext ctx, IMultiStreamSource files) + { + Contracts.AssertValue(host); + _host = host; + _host.AssertValue(ctx); + _host.AssertValue(files); + + // *** Binary format *** + // int: cached chunk size + // bool: TreatBigIntegersAsDates flag + // Schema of the loader (0x00010002) + + _columnChunkReadSize = ctx.Reader.ReadInt32(); + bool treatBigIntegersAsDates = ctx.Reader.ReadBoolean(); + + if (ctx.Header.ModelVerWritten >= 0x00010002) + { + // Load the schema + byte[] buffer = null; + if (!ctx.TryLoadBinaryStream(SchemaCtxName, r => buffer = r.ReadByteArray())) + throw _host.ExceptDecode(); + var strm = new MemoryStream(buffer, writable: false); + var loader = new BinaryLoader(_host, new BinaryLoader.Arguments(), strm); + Schema = loader.Schema; + } + + // Only load Parquet related data if a file is present. Otherwise, just the Schema is valid. + if (files.Count > 0) + { + _parquetOptions = new ParquetOptions() + { + TreatByteArrayAsString = true, + TreatBigIntegersAsDates = treatBigIntegersAsDates + }; + + _parquetStream = OpenStream(files); + DataSet schemaDataSet; + + try + { + // We only care about the schema so ignore the rows. 
+ ReaderOptions readerOptions = new ReaderOptions() + { + Count = 0, + Offset = 0 + }; + schemaDataSet = ParquetReader.Read(_parquetStream, _parquetOptions, readerOptions); + _rowCount = schemaDataSet.TotalRowCount; + } + catch (Exception ex) + { + throw new InvalidDataException("Cannot read Parquet file", ex); + } + + _columnsLoaded = InitColumns(schemaDataSet); + Schema = CreateSchema(_host, _columnsLoaded); + } + else if (Schema == null) + { + throw _host.Except("Parquet loader must be created with one file"); + } + } + + public static ParquetLoader Create(IHostEnvironment env, ModelLoadContext ctx, IMultiStreamSource files) + { + Contracts.CheckValue(env, nameof(env)); + IHost host = env.Register(LoaderName); + + env.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + env.CheckValue(files, nameof(files)); + + return host.Apply("Loading Model", + ch => new ParquetLoader(host, ctx, files)); + } + /// /// Helper function called by the ParquetLoader constructor to initialize the Columns that belong in the Parquet file. /// Composite data fields are flattened; for example, a Map Field in Parquet is flattened into a Key column and a Value /// column. /// - /// Communication channel for error reporting. - /// The array of flattened columns instantiated from the parquet file. - private void InitColumns(IChannel ch, out Column[] cols) + /// The schema data set. + /// The array of flattened columns instantiated from the parquet file. 
+ private Column[] InitColumns(DataSet dataSet) { - cols = null; List columnsLoaded = new List(); - foreach (var parquetField in _schemaDataSet.Schema.Fields) + foreach (var parquetField in dataSet.Schema.Fields) { FlattenFields(parquetField, ref columnsLoaded, false); } - cols = columnsLoaded.ToArray(); + return columnsLoaded.ToArray(); } private void FlattenFields(Field field, ref List cols, bool isRepeatable) @@ -239,7 +299,7 @@ private void FlattenFields(Field field, ref List cols, bool isRepeatable } else { - throw new InvalidDataException("Encountered unknown Parquet field type(Currently recognizes data, map, list, and struct)."); + throw _host.ExceptNotSupp("Encountered unknown Parquet field type(Currently recognizes data, map, list, and struct)."); } } @@ -326,7 +386,7 @@ private static Stream OpenStream(string filename) public long? GetRowCount(bool lazy = true) { - return _schemaDataSet.TotalRowCount; + return _rowCount; } public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) @@ -353,9 +413,22 @@ public void Save(ModelSaveContext ctx) // *** Binary format *** // int: cached chunk size // bool: TreatBigIntegersAsDates flag + // Schema of the loader ctx.Writer.Write(_columnChunkReadSize); ctx.Writer.Write(_parquetOptions.TreatBigIntegersAsDates); + + // Save the schema + var noRows = new EmptyDataView(_host, Schema); + var saverArgs = new BinarySaver.Arguments(); + saverArgs.Silent = true; + var saver = new BinarySaver(_host, saverArgs); + using (var strm = new MemoryStream()) + { + var allColumns = Enumerable.Range(0, Schema.ColumnCount).ToArray(); + saver.SaveData(strm, noRows, allColumns); + ctx.SaveBinaryStream(SchemaCtxName, w => w.WriteByteArray(strm.ToArray())); + } } private sealed class Cursor : RootCursorBase, IRowCursor @@ -377,6 +450,8 @@ public Cursor(ParquetLoader parent, Func predicate, IRandom rand) : base(parent._host) { Ch.AssertValue(predicate); + Ch.AssertValue(parent._parquetStream); + _loader = parent; _fileStream = 
parent._parquetStream; _parquetConversions = new ParquetConversions(Ch);