diff --git a/package.json b/package.json index f49c2422..3ca1385e 100644 --- a/package.json +++ b/package.json @@ -71,6 +71,7 @@ "@types/benchmark": "2.1.5", "@types/glob": "8.1.0", "@types/jest": "29.5.14", + "@types/lz4js": "0.2.1", "@types/multistream": "4.1.3", "async-done": "2.0.0", "benny": "3.7.1", @@ -96,6 +97,7 @@ "ix": "7.0.0", "jest": "29.7.0", "jest-silent-reporter": "0.6.0", + "lz4js": "0.2.0", "memfs": "4.17.2", "mkdirp": "3.0.1", "multistream": "4.1.0", @@ -114,7 +116,8 @@ "webpack": "5.99.9", "webpack-bundle-analyzer": "4.10.2", "webpack-stream": "7.0.0", - "xml2js": "0.6.2" + "xml2js": "0.6.2", + "zstd-codec": "0.1.5" }, "engines": { "node": ">=12.0" diff --git a/src/Arrow.dom.ts b/src/Arrow.dom.ts index b6e3fbce..eb03bf2b 100644 --- a/src/Arrow.dom.ts +++ b/src/Arrow.dom.ts @@ -76,6 +76,7 @@ export { RecordBatch, util, Builder, makeBuilder, builderThroughIterable, builderThroughAsyncIterable, + compressionRegistry, CompressionType, } from './Arrow.js'; export { diff --git a/src/Arrow.ts b/src/Arrow.ts index f31f91a7..ac6fe3d8 100644 --- a/src/Arrow.ts +++ b/src/Arrow.ts @@ -16,6 +16,7 @@ // under the License. export { MessageHeader } from './fb/message-header.js'; +export { CompressionType } from './fb/compression-type.js'; export { Type, @@ -92,6 +93,7 @@ export type { ReadableSource, WritableSink } from './io/stream.js'; export { RecordBatchReader, RecordBatchFileReader, RecordBatchStreamReader, AsyncRecordBatchFileReader, AsyncRecordBatchStreamReader } from './ipc/reader.js'; export { RecordBatchWriter, RecordBatchFileWriter, RecordBatchStreamWriter, RecordBatchJSONWriter } from './ipc/writer.js'; export { tableToIPC, tableFromIPC } from './ipc/serialization.js'; +export { compressionRegistry } from './ipc/compression/registry.js'; export { MessageReader, AsyncMessageReader, JSONMessageReader } from './ipc/message.js'; export { Message } from './ipc/metadata/message.js'; export { RecordBatch } from './recordbatch.js'; diff --git a/src/ipc/compression/constants.ts b/src/ipc/compression/constants.ts new file mode 100644 index 00000000..bf407543 --- /dev/null +++ b/src/ipc/compression/constants.ts @@ -0,0 +1,19 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +export const LENGTH_NO_COMPRESSED_DATA = -1; +export const COMPRESS_LENGTH_PREFIX = 8; diff --git a/src/ipc/compression/registry.ts b/src/ipc/compression/registry.ts new file mode 100644 index 00000000..af7d819c --- /dev/null +++ b/src/ipc/compression/registry.ts @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+import { CompressionType } from '../../fb/compression-type.js';
+import { compressionValidators } from './validators.js';
+
+export interface Codec {
+    encode?(data: Uint8Array): Uint8Array;
+    decode?(data: Uint8Array): Uint8Array;
+}
+
+class _CompressionRegistry {
+    protected declare registry: { [key in CompressionType]?: Codec };
+
+    constructor() {
+        this.registry = {};
+    }
+
+    set(compression: CompressionType, codec: Codec) {
+        if (codec?.encode && typeof codec.encode === 'function' && !compressionValidators[compression].isValidCodecEncode(codec)) {
+            throw new Error(`Encoder for ${CompressionType[compression]} is not valid.`);
+        }
+        this.registry[compression] = codec;
+    }
+
+    get(compression: CompressionType): Codec | null {
+        return this.registry[compression] ?? null;
+    }
+
+}
+
+export const compressionRegistry = new _CompressionRegistry();
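Reviewer note on usage: with the registry in place, consumers supply their own codec implementations; nothing ships built in. A minimal sketch using the `lz4js` devDependency added above (any implementation satisfying `Codec` and passing the validators below would do):

```ts
import * as lz4js from 'lz4js';
import { CompressionType, compressionRegistry } from 'apache-arrow';

// lz4js produces/consumes the LZ4 *frame* format, which is what the Arrow IPC
// spec requires and what Lz4FrameValidator checks at registration time.
compressionRegistry.set(CompressionType.LZ4_FRAME, {
    encode: (data: Uint8Array): Uint8Array => lz4js.compress(data),
    decode: (data: Uint8Array): Uint8Array => lz4js.decompress(data),
});
```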
diff --git a/src/ipc/compression/validators.ts b/src/ipc/compression/validators.ts
new file mode 100644
index 00000000..a0b04073
--- /dev/null
+++ b/src/ipc/compression/validators.ts
@@ -0,0 +1,92 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+import type { Codec } from './registry.js';
+import { CompressionType } from '../../fb/compression-type.js';
+
+export interface CompressionValidator {
+    isValidCodecEncode(codec: Codec): boolean;
+}
+
+class Lz4FrameValidator implements CompressionValidator {
+    private readonly LZ4_FRAME_MAGIC = new Uint8Array([4, 34, 77, 24]);
+    private readonly MIN_HEADER_LENGTH = 7; // 4 (magic) + 2 (FLG + BD) + 1 (header checksum) = 7 min bytes
+
+    isValidCodecEncode(codec: Codec): boolean {
+        const testData = new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8]);
+        const compressed = codec.encode!(testData);
+        return this._isValidCompressed(compressed);
+    }
+
+    private _isValidCompressed(buffer: Uint8Array): boolean {
+        return (
+            this._hasMinimumLength(buffer) &&
+            this._hasValidMagicNumber(buffer) &&
+            this._hasValidVersion(buffer)
+        );
+    }
+
+    private _hasMinimumLength(buffer: Uint8Array): boolean {
+        return buffer.length >= this.MIN_HEADER_LENGTH;
+    }
+
+    private _hasValidMagicNumber(buffer: Uint8Array): boolean {
+        return this.LZ4_FRAME_MAGIC.every(
+            (byte, i) => buffer[i] === byte
+        );
+    }
+
+    private _hasValidVersion(buffer: Uint8Array): boolean {
+        const flg = buffer[4];
+        const versionBits = (flg & 0xC0) >> 6;
+        return versionBits === 1;
+    }
+
+}
+
+class ZstdValidator implements CompressionValidator {
+    private readonly ZSTD_MAGIC = new Uint8Array([40, 181, 47, 253]);
+    private readonly MIN_HEADER_LENGTH = 6; // 4 (magic) + 2 (min Frame_Header) = 6 min bytes
+
+    isValidCodecEncode(codec: Codec): boolean {
+        const testData = new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8]);
+        const compressed = codec.encode!(testData);
+        return this._isValidCompressed(compressed);
+    }
+
+    private _isValidCompressed(buffer: Uint8Array): boolean {
+        return (
+            this._hasMinimumLength(buffer) &&
+            this._hasValidMagicNumber(buffer)
+        );
+    }
+
+    private _hasMinimumLength(buffer: Uint8Array): boolean {
+        return buffer.length >= this.MIN_HEADER_LENGTH;
+    }
+
+    private _hasValidMagicNumber(buffer: Uint8Array): boolean {
+        return this.ZSTD_MAGIC.every(
+            (byte, i) => buffer[i] === byte
+        );
+    }
+}
+
+export const compressionValidators: Record<CompressionType, CompressionValidator> = {
+    [CompressionType.LZ4_FRAME]: new Lz4FrameValidator(),
+    [CompressionType.ZSTD]: new ZstdValidator(),
+};
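The validators are deliberately cheap: `set()` compresses a fixed 8-byte sample once and inspects the container header, so a misconfigured encoder (e.g. LZ4 *block* output instead of frame output) fails at registration rather than producing unreadable files. A standalone sanity check of the same invariant, assuming the `lz4js` package:

```ts
import * as lz4js from 'lz4js';

// LZ4 frame magic is 04 22 4D 18 ([4, 34, 77, 24] above in decimal);
// the zstd magic is 28 B5 2F FD. Frame-format encoders must emit these.
const sample = lz4js.compress(new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8]));
console.assert(
    [0x04, 0x22, 0x4D, 0x18].every((byte, i) => sample[i] === byte),
    'encoder does not emit LZ4 frame output'
);
```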
diff --git a/src/ipc/metadata/json.ts b/src/ipc/metadata/json.ts
index bb88f0da..15f87189 100644
--- a/src/ipc/metadata/json.ts
+++ b/src/ipc/metadata/json.ts
@@ -40,7 +40,8 @@ export function recordBatchFromJSON(b: any) {
     return new RecordBatch(
         b['count'],
         fieldNodesFromJSON(b['columns']),
-        buffersFromJSON(b['columns'])
+        buffersFromJSON(b['columns']),
+        null
     );
 }

diff --git a/src/ipc/metadata/message.ts b/src/ipc/metadata/message.ts
index d3428972..17e8897b 100644
--- a/src/ipc/metadata/message.ts
+++ b/src/ipc/metadata/message.ts
@@ -40,6 +40,9 @@ import { FixedSizeBinary as _FixedSizeBinary } from '../../fb/fixed-size-binary.
 import { FixedSizeList as _FixedSizeList } from '../../fb/fixed-size-list.js';
 import { Map as _Map } from '../../fb/map.js';
 import { Message as _Message } from '../../fb/message.js';
+import { CompressionType as _CompressionType } from '../../fb/compression-type.js';
+import { BodyCompression as _BodyCompression } from '../../fb/body-compression.js';
+import { BodyCompressionMethod as _BodyCompressionMethod } from '../../fb/body-compression-method.js';
 import { Schema, Field } from '../../schema.js';
 import { toUint8Array } from '../../util/buffer.js';
@@ -122,9 +125,11 @@ export class Message {
     protected _headerType: T;
     protected _bodyLength: number;
     protected _version: MetadataVersion;
+    protected _compression: BodyCompression | null;
     public get type() { return this.headerType; }
     public get version() { return this._version; }
     public get headerType() { return this._headerType; }
+    public get compression() { return this._compression; }
     public get bodyLength() { return this._bodyLength; }
     declare protected _createHeader: MessageHeaderDecoder;
     public header() { return this._createHeader(); }
@@ -136,6 +141,7 @@ export class Message {
         this._version = version;
         this._headerType = headerType;
         this.body = new Uint8Array(0);
+        this._compression = header?.compression ?? null;
         header && (this._createHeader = () => header);
         this._bodyLength = bigIntToNumber(bodyLength);
     }
@@ -149,13 +155,21 @@ export class RecordBatch {
     protected _length: number;
     protected _nodes: FieldNode[];
     protected _buffers: BufferRegion[];
+    protected _compression: BodyCompression | null;
     public get nodes() { return this._nodes; }
     public get length() { return this._length; }
     public get buffers() { return this._buffers; }
-    constructor(length: bigint | number, nodes: FieldNode[], buffers: BufferRegion[]) {
+    public get compression() { return this._compression; }
+    constructor(
+        length: bigint | number,
+        nodes: FieldNode[],
+        buffers: BufferRegion[],
+        compression: BodyCompression | null
+    ) {
         this._nodes = nodes;
         this._buffers = buffers;
         this._length = bigIntToNumber(length);
+        this._compression = compression;
     }
 }
@@ -208,6 +222,19 @@ export class FieldNode {
     }
 }

+/**
+ * @ignore
+ * @private
+ */
+export class BodyCompression {
+    public type: _CompressionType;
+    public method: _BodyCompressionMethod;
+    constructor(type: _CompressionType, method: _BodyCompressionMethod = _BodyCompressionMethod.BUFFER) {
+        this.type = type;
+        this.method = method;
+    }
+}
+
 /** @ignore */
 function messageHeaderFromJSON(message: any, type: MessageHeader) {
     return (() => {
@@ -254,6 +281,9 @@ FieldNode['decode'] = decodeFieldNode;
 BufferRegion['encode'] = encodeBufferRegion;
 BufferRegion['decode'] = decodeBufferRegion;

+BodyCompression['encode'] = encodeBodyCompression;
+BodyCompression['decode'] = decodeBodyCompression;
+
 declare module '../../schema' {
     namespace Field {
         export { encodeField as encode };
@@ -286,6 +316,10 @@ declare module './message' {
         export { encodeBufferRegion as encode };
         export { decodeBufferRegion as decode };
     }
+    namespace BodyCompression {
+        export { encodeBodyCompression as encode };
+        export { decodeBodyCompression as decode };
+    }
 }

 /** @ignore */
@@ -296,10 +330,13 @@ function decodeSchema(_schema: _Schema, dictionaries: Map = ne

 /** @ignore */
 function decodeRecordBatch(batch: _RecordBatch, version = MetadataVersion.V5) {
-    if (batch.compression() !== null) {
-        throw new Error('Record batch compression not implemented');
-    }
-    return new RecordBatch(batch.length(), decodeFieldNodes(batch), decodeBuffers(batch, version));
+    const recordBatch = new RecordBatch(
+        batch.length(),
+        decodeFieldNodes(batch),
+        decodeBuffers(batch, version),
+        decodeBodyCompression(batch.compression())
+    );
+    return recordBatch;
 }

 /** @ignore */
@@ -491,6 +528,11 @@ function decodeFieldType(f: _Field, children?: Field[]): DataType {
     throw new Error(`Unrecognized type: "${Type[typeId]}" (${typeId})`);
 }

+/** @ignore */
+function decodeBodyCompression(b: _BodyCompression | null) {
+    return b ? new BodyCompression(b.codec(), b.method()) : null;
+}
+
 /** @ignore */
 function encodeSchema(b: Builder, schema: Schema) {
@@ -583,13 +625,29 @@ function encodeRecordBatch(b: Builder, recordBatch: RecordBatch) {
     const buffersVectorOffset = b.endVector();

+    let bodyCompressionOffset: number | null = null;
+    if (recordBatch.compression !== null) {
+        bodyCompressionOffset = encodeBodyCompression(b, recordBatch.compression);
+    }
+
     _RecordBatch.startRecordBatch(b);
     _RecordBatch.addLength(b, BigInt(recordBatch.length));
     _RecordBatch.addNodes(b, nodesVectorOffset);
     _RecordBatch.addBuffers(b, buffersVectorOffset);
+    if (bodyCompressionOffset !== null) {
+        _RecordBatch.addCompression(b, bodyCompressionOffset);
+    }
     return _RecordBatch.endRecordBatch(b);
 }

+/** @ignore */
+function encodeBodyCompression(b: Builder, node: BodyCompression) {
+    _BodyCompression.startBodyCompression(b);
+    _BodyCompression.addCodec(b, node.type);
+    _BodyCompression.addMethod(b, node.method);
+    return _BodyCompression.endBodyCompression(b);
+}
+
 /** @ignore */
 function encodeDictionaryBatch(b: Builder, dictionaryBatch: DictionaryBatch) {
     const dataOffset = RecordBatch.encode(b, dictionaryBatch.data);
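The metadata layer stays thin: a RecordBatch message either carries a `BodyCompression` clause (codec plus method, where `method` always defaults to `BUFFER`, the only method the format currently defines) or none at all, and `decodeBodyCompression()` maps the missing flatbuffer table to `null` so the reader can treat "no clause" as "uncompressed". A sketch of the one construction the writer performs — illustrative only, since the class is marked `@private` and the import path is an assumption:

```ts
import { CompressionType } from 'apache-arrow';
// Internal import path, shown for illustration:
import { BodyCompression } from 'apache-arrow/ipc/metadata/message.js';

// method is omitted here and defaults to BodyCompressionMethod.BUFFER.
const clause = new BodyCompression(CompressionType.ZSTD);
```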
diff --git a/src/ipc/reader.ts b/src/ipc/reader.ts
index f84fe83d..d99a6909 100644
--- a/src/ipc/reader.ts
+++ b/src/ipc/reader.ts
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.

-import { makeData } from '../data.js';
+import { Data, makeData } from '../data.js';
 import { Vector } from '../vector.js';
 import { DataType, Struct, TypeMap } from '../type.js';
 import { MessageHeader } from '../enum.js';
@@ -27,7 +27,7 @@ import * as metadata from './metadata/message.js';
 import { ArrayBufferViewInput } from '../util/buffer.js';
 import { ByteStream, AsyncByteStream } from '../io/stream.js';
 import { RandomAccessFile, AsyncRandomAccessFile } from '../io/file.js';
-import { VectorLoader, JSONVectorLoader } from '../visitor/vectorloader.js';
+import { VectorLoader, JSONVectorLoader, CompressedVectorLoader } from '../visitor/vectorloader.js';
 import { RecordBatch, _InternalEmptyPlaceholderRecordBatch } from '../recordbatch.js';
 import {
     FileHandle,
@@ -46,8 +46,12 @@
     isFileHandle, isFetchResponse,
     isReadableDOMStream, isReadableNodeStream
 } from '../util/compat.js';
+import { Codec, compressionRegistry } from './compression/registry.js';
+import { bigIntToNumber } from '../util/bigint.js';
+import * as flatbuffers from 'flatbuffers';

 import type { DuplexOptions, Duplex } from 'node:stream';
+import { COMPRESS_LENGTH_PREFIX, LENGTH_NO_COMPRESSED_DATA } from './compression/constants.js';

/** @ignore */ export type FromArg0 = ArrowJSONLike;
/** @ignore */ export type FromArg1 = PromiseLike<ArrowJSONLike>;
@@ -354,12 +358,31 @@ abstract class RecordBatchReaderImpl implements RecordB
         return this;
     }

-    protected _loadRecordBatch(header: metadata.RecordBatch, body: any) {
-        const children = this._loadVectors(header, body, this.schema.fields);
+    protected _loadRecordBatch(header: metadata.RecordBatch, body: Uint8Array): RecordBatch {
+        let children: Data[];
+        if (header.compression != null) {
+            const codec = compressionRegistry.get(header.compression.type);
+            if (codec?.decode && typeof codec.decode === 'function') {
+                const { decompressedBody, buffers } = this._decompressBuffers(header, body, codec);
+                children = this._loadCompressedVectors(header, decompressedBody, this.schema.fields);
+                header = new metadata.RecordBatch(
+                    header.length,
+                    header.nodes,
+                    buffers,
+                    null
+                );
+            } else {
+                throw new Error('Record batch is compressed, but no decoding codec is registered');
+            }
+        } else {
+            children = this._loadVectors(header, body, this.schema.fields);
+        }
+
         const data = makeData({ type: new Struct(this.schema.fields), length: header.length, children });
         return new RecordBatch(this.schema, data);
     }
-    protected _loadDictionaryBatch(header: metadata.DictionaryBatch, body: any) {
+
+    protected _loadDictionaryBatch(header: metadata.DictionaryBatch, body: Uint8Array) {
         const { id, isDelta } = header;
         const { dictionaries, schema } = this;
         const dictionary = dictionaries.get(id);
@@ -369,9 +392,48 @@
             new Vector(data)) : new Vector(data)).memoize() as Vector;
     }
-    protected _loadVectors(header: metadata.RecordBatch, body: any, types: (Field | DataType)[]) {
+
+    protected _loadVectors(header: metadata.RecordBatch, body: Uint8Array, types: (Field | DataType)[]) {
         return new VectorLoader(body, header.nodes, header.buffers, this.dictionaries, this.schema.metadataVersion).visitMany(types);
     }
+
+    protected _loadCompressedVectors(header: metadata.RecordBatch, body: Uint8Array[], types: (Field | DataType)[]) {
+        return new CompressedVectorLoader(body, header.nodes, header.buffers, this.dictionaries, this.schema.metadataVersion).visitMany(types);
+    }
+
+    private _decompressBuffers(header: metadata.RecordBatch, body: Uint8Array, codec: Codec): { decompressedBody: Uint8Array[]; buffers: metadata.BufferRegion[] } {
+        const decompressedBuffers: Uint8Array[] = [];
+        const newBufferRegions: metadata.BufferRegion[] = [];
+
+        let currentOffset = 0;
+        for (const { offset, length } of header.buffers) {
+            if (length === 0) {
+                decompressedBuffers.push(new Uint8Array(0));
+                newBufferRegions.push(new metadata.BufferRegion(currentOffset, 0));
+                continue;
+            }
+            const byteBuf = new flatbuffers.ByteBuffer(body.subarray(offset, offset + length));
+            const uncompressedLength = bigIntToNumber(byteBuf.readInt64(0));
+
+            const bytes = byteBuf.bytes().subarray(COMPRESS_LENGTH_PREFIX);
+
+            const decompressed = (uncompressedLength === LENGTH_NO_COMPRESSED_DATA)
+                ? bytes
+                : codec.decode!(bytes);
+
+            decompressedBuffers.push(decompressed);
+
+            const padding = ((currentOffset + 7) & ~7) - currentOffset;
+            currentOffset += padding;
+            newBufferRegions.push(new metadata.BufferRegion(currentOffset, decompressed.length));
+            currentOffset += decompressed.length;
+        }
+
+        return {
+            decompressedBody: decompressedBuffers,
+            buffers: newBufferRegions
+        };
+    }
 }

 /** @ignore */
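For review, this is the framing `_decompressBuffers()` consumes per `BufferRegion`: an 8-byte little-endian int64 holding the uncompressed length, followed by the payload, where a length of `-1` (`LENGTH_NO_COMPRESSED_DATA`) marks a payload that was stored raw because compression would not have shrunk it. The same logic in isolation:

```ts
import * as flatbuffers from 'flatbuffers';

// One framed buffer: [int64 uncompressedLength][payload ...bytes].
function unframe(region: Uint8Array, decode: (data: Uint8Array) => Uint8Array): Uint8Array {
    const bb = new flatbuffers.ByteBuffer(region);
    const uncompressedLength = Number(bb.readInt64(0)); // little-endian, as written
    const payload = region.subarray(8);                 // skip COMPRESS_LENGTH_PREFIX bytes
    return uncompressedLength === -1 ? payload : decode(payload);
}
```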
diff --git a/src/ipc/serialization.ts b/src/ipc/serialization.ts
index aee46762..c437ffeb 100644
--- a/src/ipc/serialization.ts
+++ b/src/ipc/serialization.ts
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.

+import { CompressionType } from '../fb/compression-type.js';
 import { Table } from '../table.js';
 import { TypeMap } from '../type.js';
 import { isPromise } from '../util/compat.js';
@@ -24,7 +25,7 @@ import {
     RecordBatchFileReader, RecordBatchStreamReader,
     AsyncRecordBatchFileReader, AsyncRecordBatchStreamReader
 } from './reader.js';
-import { RecordBatchFileWriter, RecordBatchStreamWriter } from './writer.js';
+import { RecordBatchFileWriter, RecordBatchStreamWriter, RecordBatchStreamWriterOptions } from './writer.js';

 type RecordBatchReaders<T extends TypeMap = any> = RecordBatchFileReader<T> | RecordBatchStreamReader<T>;
 type AsyncRecordBatchReaders<T extends TypeMap = any> = AsyncRecordBatchFileReader<T> | AsyncRecordBatchStreamReader<T>;
@@ -58,8 +59,9 @@ export function tableFromIPC(input: any): Table | Pr
  * @param table The Table to serialize.
  * @param type Whether to serialize the Table as a file or a stream.
  */
-export function tableToIPC<T extends TypeMap = any>(table: Table<T>, type: 'file' | 'stream' = 'stream'): Uint8Array {
+export function tableToIPC<T extends TypeMap = any>(table: Table<T>, type: 'file' | 'stream' = 'stream', compressionType: CompressionType | null = null): Uint8Array {
+    const writerOptions: RecordBatchStreamWriterOptions = { compressionType };
     return (type === 'stream' ? RecordBatchStreamWriter : RecordBatchFileWriter)
-        .writeAll(table)
+        .writeAll(table, writerOptions)
         .toUint8Array(true);
 }
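End to end, opting into compression is then a single extra argument, assuming a codec was registered first as sketched earlier:

```ts
import { CompressionType, tableFromArrays, tableFromIPC, tableToIPC } from 'apache-arrow';

const table = tableFromArrays({ i32: Int32Array.from([1, 2, 3, 4]) });
const bytes = tableToIPC(table, 'stream', CompressionType.LZ4_FRAME);
// No flag is needed on the way back in: the reader sees the BodyCompression
// clause in each record batch and resolves the codec from the registry.
const roundTripped = tableFromIPC(bytes);
```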
diff --git a/src/ipc/writer.ts b/src/ipc/writer.ts
index cb74fe6f..a1d5415d 100644
--- a/src/ipc/writer.ts
+++ b/src/ipc/writer.ts
@@ -36,6 +36,10 @@ import { Writable, ReadableInterop, ReadableDOMStreamOptions } from '../io/inter
 import { isPromise, isAsyncIterable, isWritableDOMStream, isWritableNodeStream, isIterable, isObject } from '../util/compat.js';

 import type { DuplexOptions, Duplex, ReadableOptions } from 'node:stream';
+import { CompressionType } from '../fb/compression-type.js';
+import { compressionRegistry } from './compression/registry.js';
+import { LENGTH_NO_COMPRESSED_DATA, COMPRESS_LENGTH_PREFIX } from './compression/constants.js';
+import * as flatbuffers from 'flatbuffers';

 export interface RecordBatchStreamWriterOptions {
     /**
@@ -49,6 +53,10 @@
      * @see https://issues.apache.org/jira/browse/ARROW-6313
      */
     writeLegacyIpcFormat?: boolean;
+    /**
+     * Specifies the optional compression algorithm to use for record batch body buffers.
+     */
+    compressionType?: CompressionType | null;
 }

 export class RecordBatchWriter<T extends TypeMap = any> extends ReadableInterop<Uint8Array> implements Writable<RecordBatch<T>> {
@@ -70,15 +78,30 @@ export class RecordBatchWriter extends ReadableInterop<
     constructor(options?: RecordBatchStreamWriterOptions) {
         super();
-        isObject(options) || (options = { autoDestroy: true, writeLegacyIpcFormat: false });
+        isObject(options) || (options = { autoDestroy: true, writeLegacyIpcFormat: false, compressionType: null });
         this._autoDestroy = (typeof options.autoDestroy === 'boolean') ? options.autoDestroy : true;
         this._writeLegacyIpcFormat = (typeof options.writeLegacyIpcFormat === 'boolean') ? options.writeLegacyIpcFormat : false;
+        if (options.compressionType != null) {
+            if (this._writeLegacyIpcFormat) {
+                throw new Error('The legacy IPC format does not support body compression. Use the modern IPC format (writeLegacyIpcFormat=false).');
+            }
+            if (Object.values(CompressionType).includes(options.compressionType)) {
+                this._compression = new metadata.BodyCompression(options.compressionType);
+            } else {
+                const validCompressionTypes = Object.values(CompressionType)
+                    .filter((v): v is string => typeof v === 'string');
+                throw new Error(`Unsupported compressionType: ${options.compressionType}. Available types: ${validCompressionTypes.join(', ')}`);
+            }
+        } else {
+            this._compression = null;
+        }
     }

     protected _position = 0;
     protected _started = false;
     protected _autoDestroy: boolean;
     protected _writeLegacyIpcFormat: boolean;
+    protected _compression: metadata.BodyCompression | null = null;
     // @ts-ignore
     protected _sink = new AsyncByteQueue();
     protected _schema: Schema | null = null;
@@ -251,8 +274,8 @@
     }

     protected _writeRecordBatch(batch: RecordBatch) {
-        const { byteLength, nodes, bufferRegions, buffers } = VectorAssembler.assemble(batch);
-        const recordBatch = new metadata.RecordBatch(batch.numRows, nodes, bufferRegions);
+        const { byteLength, nodes, bufferRegions, buffers } = this._assembleRecordBatch(batch);
+        const recordBatch = new metadata.RecordBatch(batch.numRows, nodes, bufferRegions, this._compression);
         const message = Message.from(recordBatch, byteLength);
         return this
             ._writeDictionaries(batch)
@@ -260,25 +283,90 @@
             ._writeMessage(message)
             ._writeBodyBuffers(buffers);
     }

+    protected _assembleRecordBatch(batch: RecordBatch) {
+        let { byteLength, nodes, bufferRegions, buffers } = VectorAssembler.assemble(batch);
+        if (this._compression != null) {
+            ({ byteLength, bufferRegions, buffers } = this._compressBodyBuffers(buffers));
+        }
+        return { byteLength, nodes, bufferRegions, buffers };
+    }
+
+    protected _compressBodyBuffers(buffers: ArrayBufferView[]) {
+        const codec = compressionRegistry.get(this._compression!.type!);
+
+        if (!codec?.encode || typeof codec.encode !== 'function') {
+            throw new Error(`Codec for compression type "${CompressionType[this._compression!.type!]}" is missing a valid encode method`);
+        }
+
+        let currentOffset = 0;
+        const compressedBuffers: ArrayBufferView[] = [];
+        const bufferRegions: metadata.BufferRegion[] = [];
+
+        for (const buffer of buffers) {
+            const byteBuf = toUint8Array(buffer);
+
+            if (byteBuf.length === 0) {
+                compressedBuffers.push(new Uint8Array(0), new Uint8Array(0));
+                bufferRegions.push(new metadata.BufferRegion(currentOffset, 0));
+                continue;
+            }
+
+            const compressed = codec.encode(byteBuf);
+            const isCompressionEffective = compressed.length < byteBuf.length;
+
+            const finalBuffer = isCompressionEffective ? compressed : byteBuf;
+            const byteLength = isCompressionEffective ? finalBuffer.length : LENGTH_NO_COMPRESSED_DATA;
+
+            const lengthPrefix = new flatbuffers.ByteBuffer(new Uint8Array(COMPRESS_LENGTH_PREFIX));
+            lengthPrefix.writeInt64(0, BigInt(byteLength));
+
+            compressedBuffers.push(lengthPrefix.bytes(), new Uint8Array(finalBuffer));
+
+            const padding = ((currentOffset + 7) & ~7) - currentOffset;
+            currentOffset += padding;
+
+            const fullBodyLength = COMPRESS_LENGTH_PREFIX + finalBuffer.length;
+            bufferRegions.push(new metadata.BufferRegion(currentOffset, fullBodyLength));
+
+            currentOffset += fullBodyLength;
+        }
+        const finalPadding = ((currentOffset + 7) & ~7) - currentOffset;
+        currentOffset += finalPadding;
+
+        return { byteLength: currentOffset, bufferRegions, buffers: compressedBuffers };
+    }
+
     protected _writeDictionaryBatch(dictionary: Data, id: number, isDelta = false) {
         const { byteLength, nodes, bufferRegions, buffers } = VectorAssembler.assemble(new Vector([dictionary]));
-        const recordBatch = new metadata.RecordBatch(dictionary.length, nodes, bufferRegions);
+        const recordBatch = new metadata.RecordBatch(dictionary.length, nodes, bufferRegions, null);
         const dictionaryBatch = new metadata.DictionaryBatch(recordBatch, id, isDelta);
         const message = Message.from(dictionaryBatch, byteLength);
         return this
             ._writeMessage(message)
-            ._writeBodyBuffers(buffers);
+            ._writeBodyBuffers(buffers, 'dictionary');
     }

-    protected _writeBodyBuffers(buffers: ArrayBufferView[]) {
-        let buffer: ArrayBufferView;
-        let size: number, padding: number;
-        for (let i = -1, n = buffers.length; ++i < n;) {
-            if ((buffer = buffers[i]) && (size = buffer.byteLength) > 0) {
-                this._write(buffer);
-                if ((padding = ((size + 7) & ~7) - size) > 0) {
-                    this._writePadding(padding);
-                }
+    protected _writeBodyBuffers(buffers: ArrayBufferView[], batchType: 'record' | 'dictionary' = 'record') {
+        const bufGroupSize = batchType === 'dictionary'
+            ? 1
+            : this._compression != null ? 2 : 1;
+        const bufs = new Array(bufGroupSize);
+
+        for (let i = 0; i < buffers.length; i += bufGroupSize) {
+            let size = 0;
+            for (let j = -1; ++j < bufGroupSize;) {
+                bufs[j] = buffers[i + j];
+                size += bufs[j].byteLength;
+            }
+
+            if (size === 0) {
+                continue;
+            }
+
+            for (const buf of bufs) this._write(buf);
+            const padding = ((size + 7) & ~7) - size;
+            if (padding > 0) {
+                this._writePadding(padding);
             }
         }
         return this;
     }
@@ -325,13 +413,13 @@ export class RecordBatchStreamWriter extends RecordBatc

 /** @ignore */
 export class RecordBatchFileWriter<T extends TypeMap = any> extends RecordBatchWriter<T> {
-    public static writeAll<T extends TypeMap = any>(input: Table<T> | Iterable<RecordBatch<T>>): RecordBatchFileWriter<T>;
-    public static writeAll<T extends TypeMap = any>(input: AsyncIterable<RecordBatch<T>>): Promise<RecordBatchFileWriter<T>>;
-    public static writeAll<T extends TypeMap = any>(input: PromiseLike<AsyncIterable<RecordBatch<T>>>): Promise<RecordBatchFileWriter<T>>;
-    public static writeAll<T extends TypeMap = any>(input: PromiseLike<Table<T> | Iterable<RecordBatch<T>>>): Promise<RecordBatchFileWriter<T>>;
+    public static writeAll<T extends TypeMap = any>(input: Table<T> | Iterable<RecordBatch<T>>, options?: RecordBatchStreamWriterOptions): RecordBatchFileWriter<T>;
+    public static writeAll<T extends TypeMap = any>(input: AsyncIterable<RecordBatch<T>>, options?: RecordBatchStreamWriterOptions): Promise<RecordBatchFileWriter<T>>;
+    public static writeAll<T extends TypeMap = any>(input: PromiseLike<AsyncIterable<RecordBatch<T>>>, options?: RecordBatchStreamWriterOptions): Promise<RecordBatchFileWriter<T>>;
+    public static writeAll<T extends TypeMap = any>(input: PromiseLike<Table<T> | Iterable<RecordBatch<T>>>, options?: RecordBatchStreamWriterOptions): Promise<RecordBatchFileWriter<T>>;
     /** @nocollapse */
-    public static writeAll<T extends TypeMap = any>(input: any) {
-        const writer = new RecordBatchFileWriter<T>();
+    public static writeAll<T extends TypeMap = any>(input: any, options?: RecordBatchStreamWriterOptions) {
+        const writer = new RecordBatchFileWriter<T>(options);
         if (isPromise(input)) {
             return input.then((x) => writer.writeAll(x));
         } else if (isAsyncIterable<RecordBatch<T>>(input)) {
@@ -340,9 +428,10 @@ export class RecordBatchFileWriter extends RecordBatchW
         return writeAll(writer, input);
     }

-    constructor() {
-        super();
+    constructor(options?: RecordBatchStreamWriterOptions) {
+        super(options);
         this._autoDestroy = true;
+        this._writeLegacyIpcFormat = false;
     }

     // @ts-ignore
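A layout note on the write path: with compression enabled, `_compressBodyBuffers()` emits buffers as (length-prefix, payload) pairs, so `_writeBodyBuffers()` walks the array in groups of two and pads each pair as a unit; every region start stays 8-byte aligned, mirroring the reader's `_decompressBuffers()`. The offset arithmetic in isolation:

```ts
// Where the next region starts, given the current region's start offset and
// its compressed payload length: 8 prefix bytes + payload, rounded up to the
// next multiple of 8. E.g. regionEnd(0, 10) === 24, regionEnd(24, 3) === 40.
function regionEnd(offset: number, payloadLength: number): number {
    const length = 8 + payloadLength; // COMPRESS_LENGTH_PREFIX + payload
    return (offset + length + 7) & ~7;
}
```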
diff --git a/src/visitor/vectorloader.ts b/src/visitor/vectorloader.ts
index 198c32ff..7c82e7ab 100644
--- a/src/visitor/vectorloader.ts
+++ b/src/visitor/vectorloader.ts
@@ -41,7 +41,7 @@ export class VectorLoader extends Visitor {
     private nodes: FieldNode[];
     private nodesIndex = -1;
     private buffers: BufferRegion[];
-    private buffersIndex = -1;
+    protected buffersIndex = -1;
     private dictionaries: Map<number, Vector>;
     private readonly metadataVersion: MetadataVersion;
     constructor(bytes: Uint8Array, nodes: FieldNode[], buffers: BufferRegion[], dictionaries: Map<number, Vector>, metadataVersion = MetadataVersion.V5) {
@@ -205,3 +205,14 @@ function binaryDataFromJSON(values: string[]) {
     }
     return data;
 }
+
+export class CompressedVectorLoader extends VectorLoader {
+    private bodyChunks: Uint8Array[];
+    constructor(bodyChunks: Uint8Array[], nodes: FieldNode[], buffers: BufferRegion[], dictionaries: Map<number, Vector>, metadataVersion: MetadataVersion) {
+        super(new Uint8Array(0), nodes, buffers, dictionaries, metadataVersion);
+        this.bodyChunks = bodyChunks;
+    }
+    protected readData<T extends DataType>(_type: T, _buffer = this.nextBufferRange()) {
+        return this.bodyChunks[this.buffersIndex];
+    }
+}
diff --git a/test/tsconfig.json b/test/tsconfig.json
index bd43e091..e1ad1388 100644
--- a/test/tsconfig.json
+++ b/test/tsconfig.json
@@ -17,6 +17,7 @@
     "inlineSourceMap": false,
     "downlevelIteration": false,
     "baseUrl": "../",
+    "typeRoots": ["../node_modules/@types", "./types"],
     "paths": {
       "apache-arrow": ["src/Arrow.node"],
       "apache-arrow/*": ["src/*"]
diff --git a/test/types/zstd-codec.d.ts b/test/types/zstd-codec.d.ts
new file mode 100644
index 00000000..76176d36
--- /dev/null
+++ b/test/types/zstd-codec.d.ts
@@ -0,0 +1,22 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+declare module 'zstd-codec' {
+    export const ZstdCodec: {
+        run(callback: (zstd: any) => void): void;
+    };
+}
diff --git a/test/unit/ipc/writer/file-writer-tests.ts b/test/unit/ipc/writer/file-writer-tests.ts
index 2b99d0f7..5fbcd20e 100644
--- a/test/unit/ipc/writer/file-writer-tests.ts
+++ b/test/unit/ipc/writer/file-writer-tests.ts
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.

+import { RecordBatchStreamWriterOptions } from 'apache-arrow/ipc/writer.js';
 import {
     generateDictionaryTables, generateRandomTables
 } from '../../../data/tables.js';
@@ -23,6 +24,7 @@ import { validateRecordBatchIterator } from '../validate.js';
 import {
     builderThroughIterable,
+    CompressionType,
     Dictionary,
     Int32,
     RecordBatch,
@@ -32,6 +34,41 @@ import {
     Uint32,
     Vector
 } from 'apache-arrow';
+import { Codec, compressionRegistry } from 'apache-arrow/ipc/compression/registry.js';
+import * as lz4js from 'lz4js';
+
+export async function registerCompressionCodecs(): Promise<void> {
+    if (compressionRegistry.get(CompressionType.LZ4_FRAME) === null) {
+        const lz4Codec: Codec = {
+            encode(data: Uint8Array): Uint8Array {
+                return lz4js.compress(data);
+            },
+            decode(data: Uint8Array): Uint8Array {
+                return lz4js.decompress(data);
+            }
+        };
+        compressionRegistry.set(CompressionType.LZ4_FRAME, lz4Codec);
+    }
+
+    if (compressionRegistry.get(CompressionType.ZSTD) === null) {
+        const { ZstdCodec } = await import('zstd-codec');
+        await new Promise<void>((resolve) => {
+            ZstdCodec.run((zstd: any) => {
+                const simple = new zstd.Simple();
+                const zstdCodec: Codec = {
+                    encode(data: Uint8Array): Uint8Array {
+                        return simple.compress(data);
+                    },
+                    decode(data: Uint8Array): Uint8Array {
+                        return simple.decompress(data);
+                    }
+                };
+                compressionRegistry.set(CompressionType.ZSTD, zstdCodec);
+                resolve();
+            });
+        });
+    }
+}

 describe('RecordBatchFileWriter', () => {
     for (const table of generateRandomTables([10, 20, 30])) {
@@ -41,6 +78,17 @@
         testFileWriter(table, `${table.schema.fields[0]}`);
     }

+    const compressionTypes = [CompressionType.LZ4_FRAME, CompressionType.ZSTD];
+    beforeAll(async () => {
+        await registerCompressionCodecs();
+    });
+
+    const table = generate.table([10, 20, 30]).table;
+    for (const compressionType of compressionTypes) {
+        const testName = `[${table.schema.fields.join(', ')}] - ${CompressionType[compressionType]} compressed`;
+        testFileWriter(table, testName, { compressionType });
+    }
+
     it('should throw if attempting to write replacement dictionary batches', async () => {
         const type = new Dictionary(new Uint32, new Int32, 0);
         const writer = new RecordBatchFileWriter();
@@ -91,14 +139,14 @@
describe('RecordBatchFileWriter', () => {
     });
 });

-function testFileWriter(table: Table, name: string) {
+function testFileWriter(table: Table, name: string, options?: RecordBatchStreamWriterOptions) {
     describe(`should write the Arrow IPC file format (${name})`, () => {
-        test(`Table`, validateTable.bind(0, table));
+        test(`Table`, validateTable.bind(0, table, options));
     });
 }

-async function validateTable(source: Table) {
-    const writer = RecordBatchFileWriter.writeAll(source);
+async function validateTable(source: Table, options?: RecordBatchStreamWriterOptions) {
+    const writer = RecordBatchFileWriter.writeAll(source, options);
     const result = new Table(RecordBatchReader.from(await writer.toUint8Array()));
     validateRecordBatchIterator(3, source.batches);
     expect(result).toEqualTable(source);
diff --git a/test/unit/ipc/writer/stream-writer-tests.ts b/test/unit/ipc/writer/stream-writer-tests.ts
index 11bbe736..e32af095 100644
--- a/test/unit/ipc/writer/stream-writer-tests.ts
+++ b/test/unit/ipc/writer/stream-writer-tests.ts
@@ -25,6 +25,7 @@ import { validateRecordBatchIterator } from '../validate.js';
 import type { RecordBatchStreamWriterOptions } from 'apache-arrow/ipc/writer';
 import {
     builderThroughIterable,
+    CompressionType,
     Data,
     Dictionary,
     Field,
@@ -37,6 +38,41 @@ import {
     Uint32,
     Vector
 } from 'apache-arrow';
+import { Codec, compressionRegistry } from 'apache-arrow/ipc/compression/registry.js';
+import * as lz4js from 'lz4js';
+
+export async function registerCompressionCodecs(): Promise<void> {
+    if (compressionRegistry.get(CompressionType.LZ4_FRAME) === null) {
+        const lz4Codec: Codec = {
+            encode(data: Uint8Array): Uint8Array {
+                return lz4js.compress(data);
+            },
+            decode(data: Uint8Array): Uint8Array {
+                return lz4js.decompress(data);
+            }
+        };
+        compressionRegistry.set(CompressionType.LZ4_FRAME, lz4Codec);
+    }
+
+    if (compressionRegistry.get(CompressionType.ZSTD) === null) {
+        const { ZstdCodec } = await import('zstd-codec');
+        await new Promise<void>((resolve) => {
+            ZstdCodec.run((zstd: any) => {
+                const simple = new zstd.Simple();
+                const zstdCodec: Codec = {
+                    encode(data: Uint8Array): Uint8Array {
+                        return simple.compress(data);
+                    },
+                    decode(data: Uint8Array): Uint8Array {
+                        return simple.decompress(data);
+                    }
+                };
+                compressionRegistry.set(CompressionType.ZSTD, zstdCodec);
+                resolve();
+            });
+        });
+    }
+}

 describe('RecordBatchStreamWriter', () => {

@@ -47,6 +83,16 @@
     testStreamWriter(table, testName, { writeLegacyIpcFormat: true });
     testStreamWriter(table, testName, { writeLegacyIpcFormat: false });

+    const compressionTypes = [CompressionType.LZ4_FRAME, CompressionType.ZSTD];
+    beforeAll(async () => {
+        await registerCompressionCodecs();
+    });
+
+    for (const compressionType of compressionTypes) {
+        const testName = `[${table.schema.fields.join(', ')}] - ${CompressionType[compressionType]} compressed`;
+        testStreamWriter(table, testName, { compressionType });
+    }
+
     for (const table of generateRandomTables([10, 20, 30])) {
         const testName = `[${table.schema.fields.join(', ')}]`;
         testStreamWriter(table, testName, { writeLegacyIpcFormat: true });
diff --git a/yarn.lock b/yarn.lock
index 3d0c694f..7da66143 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -1390,6 +1390,11 @@
   resolved "https://registry.yarnpkg.com/@types/json-schema/-/json-schema-7.0.15.tgz#596a1747233694d50f6ad8a7869fcb6f56cf5841"
   integrity sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==

+"@types/lz4js@0.2.1":
+  version "0.2.1"
+  resolved
"https://registry.yarnpkg.com/@types/lz4js/-/lz4js-0.2.1.tgz#44214fe6b28187ff36eee03afc2b344cbd886b3e" + integrity sha512-aAnbA4uKPNqZqu0XK1QAwKP0Wskb4Oa7ZFqxW5CMIyGgqYQKFgBxTfK3m3KODXoOLv5t14VregzgrEak13uGQA== + "@types/minimatch@^5.1.2": version "5.1.2" resolved "https://registry.yarnpkg.com/@types/minimatch/-/minimatch-5.1.2.tgz#07508b45797cb81ec3f273011b054cd0755eddca" @@ -5024,6 +5029,11 @@ lunr@^2.3.9: resolved "https://registry.yarnpkg.com/lunr/-/lunr-2.3.9.tgz#18b123142832337dd6e964df1a5a7707b25d35e1" integrity sha512-zTU3DaZaF3Rt9rhN3uBMGQD3dD2/vFQqnvZCDv4dl5iOzq2IZQqTxu90r4E5J+nP70J3ilqVCrbho2eWaeW8Ow== +lz4js@0.2.0: + version "0.2.0" + resolved "https://registry.yarnpkg.com/lz4js/-/lz4js-0.2.0.tgz#09f1a397cb2158f675146c3351dde85058cb322f" + integrity sha512-gY2Ia9Lm7Ep8qMiuGRhvUq0Q7qUereeldZPP1PMEJxPtEWHJLqw9pgX68oHajBH0nzJK4MaZEA/YNV3jT8u8Bg== + make-dir@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/make-dir/-/make-dir-4.0.0.tgz#c3c2307a771277cd9638305f915c29ae741b614e" @@ -7320,3 +7330,8 @@ yocto-queue@^0.1.0: version "0.1.0" resolved "https://registry.yarnpkg.com/yocto-queue/-/yocto-queue-0.1.0.tgz#0294eb3dee05028d31ee1a5fa2c556a6aaf10a1b" integrity sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q== + +zstd-codec@0.1.5: + version "0.1.5" + resolved "https://registry.yarnpkg.com/zstd-codec/-/zstd-codec-0.1.5.tgz#c180193e4603ef74ddf704bcc835397d30a60e42" + integrity sha512-v3fyjpK8S/dpY/X5WxqTK3IoCnp/ZOLxn144GZVlNUjtwAchzrVo03h+oMATFhCIiJ5KTr4V3vDQQYz4RU684g==