From f4cd026efd6bc80ff2430e7a5f0281f01e6986e2 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Tue, 29 Apr 2025 19:30:57 -0700 Subject: [PATCH 01/55] Add Python files --- Lib/compression/zstd/__init__.py | 286 +++ Lib/compression/zstd/zstdfile.py | 378 ++++ Lib/shutil.py | 19 +- Lib/tarfile.py | 61 +- Lib/test/support/__init__.py | 9 +- Lib/test/test_shutil.py | 4 + Lib/test/test_tarfile.py | 44 +- Lib/test/test_zipfile/test_core.py | 28 +- Lib/test/test_zstd/__init__.py | 5 + Lib/test/test_zstd/__main__.py | 7 + Lib/test/test_zstd/test_core.py | 2693 ++++++++++++++++++++++++++++ Lib/zipfile/__init__.py | 20 +- Makefile.pre.in | 3 +- 13 files changed, 3548 insertions(+), 9 deletions(-) create mode 100644 Lib/compression/zstd/__init__.py create mode 100644 Lib/compression/zstd/zstdfile.py create mode 100644 Lib/test/test_zstd/__init__.py create mode 100644 Lib/test/test_zstd/__main__.py create mode 100644 Lib/test/test_zstd/test_core.py diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py new file mode 100644 index 00000000000000..731da1a9598392 --- /dev/null +++ b/Lib/compression/zstd/__init__.py @@ -0,0 +1,286 @@ +"""Python bindings to Zstandard (zstd) compression library, the API style is +similar to Python's bz2/lzma/zlib modules. +""" + +__all__ = ( + # From this file + "compressionLevel_values", + "get_frame_info", + "CParameter", + "DParameter", + "Strategy", + "finalize_dict", + "train_dict", + "zstd_support_multithread", + "compress", + "decompress", + # From _zstd + "ZstdCompressor", + "ZstdDecompressor", + "ZstdDict", + "ZstdError", + "get_frame_size", + "zstd_version", + "zstd_version_info", + # From zstd.zstdfile + "open", + "ZstdFile", +) + +from collections import namedtuple +from enum import IntEnum +from functools import lru_cache + +from compression.zstd.zstdfile import ZstdFile, open +from _zstd import * + +import _zstd + + +_ZSTD_CStreamSizes = _zstd._ZSTD_CStreamSizes +_ZSTD_DStreamSizes = _zstd._ZSTD_DStreamSizes +_train_dict = _zstd._train_dict +_finalize_dict = _zstd._finalize_dict + + +# TODO(emmatyping): these should be dataclasses or some other class, not namedtuples + +# compressionLevel_values +_nt_values = namedtuple("values", ["default", "min", "max"]) +compressionLevel_values = _nt_values(*_zstd._compressionLevel_values) + + +_nt_frame_info = namedtuple("frame_info", ["decompressed_size", "dictionary_id"]) + + +def get_frame_info(frame_buffer): + """Get zstd frame information from a frame header. + + Parameter + frame_buffer: A bytes-like object. It should starts from the beginning of + a frame, and needs to include at least the frame header (6 to + 18 bytes). + + Return a two-items namedtuple: (decompressed_size, dictionary_id) + + If decompressed_size is None, decompressed size is unknown. + + dictionary_id is a 32-bit unsigned integer value. 0 means dictionary ID was + not recorded in the frame header, the frame may or may not need a dictionary + to be decoded, and the ID of such a dictionary is not specified. + + It's possible to append more items to the namedtuple in the future.""" + + ret_tuple = _zstd._get_frame_info(frame_buffer) + return _nt_frame_info(*ret_tuple) + + +def _nbytes(dat): + if isinstance(dat, (bytes, bytearray)): + return len(dat) + with memoryview(dat) as mv: + return mv.nbytes + + +def train_dict(samples, dict_size): + """Train a zstd dictionary, return a ZstdDict object. + + Parameters + samples: An iterable of samples, a sample is a bytes-like object + represents a file. 
+ dict_size: The dictionary's maximum size, in bytes. + """ + # Check argument's type + if not isinstance(dict_size, int): + raise TypeError('dict_size argument should be an int object.') + + # Prepare data + chunks = [] + chunk_sizes = [] + for chunk in samples: + chunks.append(chunk) + chunk_sizes.append(_nbytes(chunk)) + + chunks = b''.join(chunks) + if not chunks: + raise ValueError("The samples are empty content, can't train dictionary.") + + # samples_bytes: samples be stored concatenated in a single flat buffer. + # samples_size_list: a list of each sample's size. + # dict_size: size of the dictionary, in bytes. + dict_content = _train_dict(chunks, chunk_sizes, dict_size) + + return ZstdDict(dict_content) + + +def finalize_dict(zstd_dict, samples, dict_size, level): + """Finalize a zstd dictionary, return a ZstdDict object. + + Given a custom content as a basis for dictionary, and a set of samples, + finalize dictionary by adding headers and statistics according to the zstd + dictionary format. + + You may compose an effective dictionary content by hand, which is used as + basis dictionary, and use some samples to finalize a dictionary. The basis + dictionary can be a "raw content" dictionary, see is_raw parameter in + ZstdDict.__init__ method. + + Parameters + zstd_dict: A ZstdDict object, basis dictionary. + samples: An iterable of samples, a sample is a bytes-like object + represents a file. + dict_size: The dictionary's maximum size, in bytes. + level: The compression level expected to use in production. The + statistics for each compression level differ, so tuning the + dictionary for the compression level can help quite a bit. + """ + + # Check arguments' type + if not isinstance(zstd_dict, ZstdDict): + raise TypeError('zstd_dict argument should be a ZstdDict object.') + if not isinstance(dict_size, int): + raise TypeError('dict_size argument should be an int object.') + if not isinstance(level, int): + raise TypeError('level argument should be an int object.') + + # Prepare data + chunks = [] + chunk_sizes = [] + for chunk in samples: + chunks.append(chunk) + chunk_sizes.append(_nbytes(chunk)) + + chunks = b''.join(chunks) + if not chunks: + raise ValueError("The samples are empty content, can't finalize dictionary.") + + # custom_dict_bytes: existing dictionary. + # samples_bytes: samples be stored concatenated in a single flat buffer. + # samples_size_list: a list of each sample's size. + # dict_size: maximal size of the dictionary, in bytes. + # compression_level: compression level expected to use in production. + dict_content = _finalize_dict(zstd_dict.dict_content, + chunks, chunk_sizes, + dict_size, level) + + return _zstd.ZstdDict(dict_content) + +def compress(data, level=None, options=None, zstd_dict=None): + """Compress a block of data, return a bytes object of zstd compressed data. + + Refer to ZstdCompressor's docstring for a description of the + optional arguments *level*, *options*, and *zstd_dict*. + + For incremental compression, use an ZstdCompressor instead. + """ + comp = ZstdCompressor(level=level, options=options, zstd_dict=zstd_dict) + return comp.compress(data, ZstdCompressor.FLUSH_FRAME) + +def decompress(data, zstd_dict=None, options=None): + """Decompress one or more frames of data. + + Refer to ZstdDecompressor's docstring for a description of the + optional arguments *zstd_dict*, *options*. + + For incremental decompression, use an ZstdDecompressor instead. 
+ """ + results = [] + while True: + decomp = ZstdDecompressor(options=options, zstd_dict=zstd_dict) + try: + res = decomp.decompress(data) + except ZstdError: + if results: + break # Leftover data is not a valid LZMA/XZ stream; ignore it. + else: + raise # Error on the first iteration; bail out. + results.append(res) + if not decomp.eof: + raise ZstdError("Compressed data ended before the " + "end-of-stream marker was reached") + data = decomp.unused_data + if not data: + break + return b"".join(results) + +class _UnsupportedCParameter: + def __set_name__(self, _, name): + self.name = name + + def __get__(self, *_, **__): + msg = ("%s CParameter not available, zstd version is %s.") % ( + self.name, + zstd_version, + ) + raise NotImplementedError(msg) + + +class CParameter(IntEnum): + """Compression parameters""" + + compressionLevel = _zstd._ZSTD_c_compressionLevel + windowLog = _zstd._ZSTD_c_windowLog + hashLog = _zstd._ZSTD_c_hashLog + chainLog = _zstd._ZSTD_c_chainLog + searchLog = _zstd._ZSTD_c_searchLog + minMatch = _zstd._ZSTD_c_minMatch + targetLength = _zstd._ZSTD_c_targetLength + strategy = _zstd._ZSTD_c_strategy + + targetCBlockSize = _UnsupportedCParameter() + + enableLongDistanceMatching = _zstd._ZSTD_c_enableLongDistanceMatching + ldmHashLog = _zstd._ZSTD_c_ldmHashLog + ldmMinMatch = _zstd._ZSTD_c_ldmMinMatch + ldmBucketSizeLog = _zstd._ZSTD_c_ldmBucketSizeLog + ldmHashRateLog = _zstd._ZSTD_c_ldmHashRateLog + + contentSizeFlag = _zstd._ZSTD_c_contentSizeFlag + checksumFlag = _zstd._ZSTD_c_checksumFlag + dictIDFlag = _zstd._ZSTD_c_dictIDFlag + + nbWorkers = _zstd._ZSTD_c_nbWorkers + jobSize = _zstd._ZSTD_c_jobSize + overlapLog = _zstd._ZSTD_c_overlapLog + + @lru_cache(maxsize=None) + def bounds(self): + """Return lower and upper bounds of a compression parameter, both inclusive.""" + # 1 means compression parameter + return _zstd._get_param_bounds(1, self.value) + + +class DParameter(IntEnum): + """Decompression parameters""" + + windowLogMax = _zstd._ZSTD_d_windowLogMax + + @lru_cache(maxsize=None) + def bounds(self): + """Return lower and upper bounds of a decompression parameter, both inclusive.""" + # 0 means decompression parameter + return _zstd._get_param_bounds(0, self.value) + + +class Strategy(IntEnum): + """Compression strategies, listed from fastest to strongest. + + Note : new strategies _might_ be added in the future, only the order + (from fast to strong) is guaranteed. + """ + + fast = _zstd._ZSTD_fast + dfast = _zstd._ZSTD_dfast + greedy = _zstd._ZSTD_greedy + lazy = _zstd._ZSTD_lazy + lazy2 = _zstd._ZSTD_lazy2 + btlazy2 = _zstd._ZSTD_btlazy2 + btopt = _zstd._ZSTD_btopt + btultra = _zstd._ZSTD_btultra + btultra2 = _zstd._ZSTD_btultra2 + + +# Set CParameter/DParameter types for validity check +_zstd._set_parameter_types(CParameter, DParameter) + +zstd_support_multithread = CParameter.nbWorkers.bounds() != (0, 0) diff --git a/Lib/compression/zstd/zstdfile.py b/Lib/compression/zstd/zstdfile.py new file mode 100644 index 00000000000000..1ca60fe5677454 --- /dev/null +++ b/Lib/compression/zstd/zstdfile.py @@ -0,0 +1,378 @@ +import builtins +import io + +from os import PathLike + +from _zstd import ZstdCompressor, ZstdDecompressor, _ZSTD_DStreamSizes, ZstdError +from compression._common import _streams + +__all__ = ("ZstdFile", "open") + +_ZSTD_DStreamOutSize = _ZSTD_DStreamSizes[1] + +_MODE_CLOSED = 0 +_MODE_READ = 1 +_MODE_WRITE = 2 + + +class ZstdFile(_streams.BaseStream): + """A file object providing transparent zstd (de)compression. 
+ + A ZstdFile can act as a wrapper for an existing file object, or refer + directly to a named file on disk. + + Note that ZstdFile provides a *binary* file interface - data read is + returned as bytes, and data to be written should be an object that + supports the Buffer Protocol. + """ + + _READER_CLASS = _streams.DecompressReader + + FLUSH_BLOCK = ZstdCompressor.FLUSH_BLOCK + FLUSH_FRAME = ZstdCompressor.FLUSH_FRAME + + def __init__( + self, + filename, + mode="r", + *, + level=None, + options=None, + zstd_dict=None, + ): + """Open a zstd compressed file in binary mode. + + filename can be either an actual file name (given as a str, bytes, or + PathLike object), in which case the named file is opened, or it can be + an existing file object to read from or write to. + + mode can be "r" for reading (default), "w" for (over)writing, "x" for + creating exclusively, or "a" for appending. These can equivalently be + given as "rb", "wb", "xb" and "ab" respectively. + + Parameters + level: The compression level to use, defaults to ZSTD_CLEVEL_DEFAULT. Note, + in read mode (decompression), compression level is not supported. + options: A dict object, containing advanced compression + parameters. + zstd_dict: A ZstdDict object, pre-trained dictionary for compression / + decompression. + """ + self._fp = None + self._closefp = False + self._mode = _MODE_CLOSED + + # Read or write mode + if mode in ("r", "rb"): + if not isinstance(options, (type(None), dict)): + raise TypeError( + ( + "In read mode (decompression), options argument " + "should be a dict object, that represents decompression " + "options." + ) + ) + if level: + raise TypeError("level argument should only be passed when writing.") + mode_code = _MODE_READ + elif mode in ("w", "wb", "a", "ab", "x", "xb"): + if not isinstance(level, (type(None), int)): + raise TypeError(("level argument should be an int object.")) + if not isinstance(options, (type(None), dict)): + raise TypeError(("options argument should be an dict object.")) + mode_code = _MODE_WRITE + self._compressor = ZstdCompressor( + level=level, options=options, zstd_dict=zstd_dict + ) + self._pos = 0 + else: + raise ValueError("Invalid mode: {!r}".format(mode)) + + # File object + if isinstance(filename, (str, bytes, PathLike)): + if "b" not in mode: + mode += "b" + self._fp = builtins.open(filename, mode) + self._closefp = True + elif hasattr(filename, "read") or hasattr(filename, "write"): + self._fp = filename + else: + raise TypeError(("filename must be a str, bytes, file or PathLike object")) + self._mode = mode_code + + if self._mode == _MODE_READ: + raw = self._READER_CLASS( + self._fp, + ZstdDecompressor, + trailing_error=ZstdError, + zstd_dict=zstd_dict, + options=options, + ) + self._buffer = io.BufferedReader(raw) + + def close(self): + """Flush and close the file. + + May be called more than once without error. Once the file is + closed, any other operation on it will raise a ValueError. + """ + # Nop if already closed + if self._fp is None: + return + try: + if self._mode == _MODE_READ: + if hasattr(self, "_buffer") and self._buffer: + self._buffer.close() + self._buffer = None + elif self._mode == _MODE_WRITE: + self.flush(self.FLUSH_FRAME) + self._compressor = None + finally: + self._mode = _MODE_CLOSED + try: + if self._closefp: + self._fp.close() + finally: + self._fp = None + self._closefp = False + + def write(self, data): + """Write a bytes-like object to the file. 
+ + Returns the number of uncompressed bytes written, which is + always the length of data in bytes. Note that due to buffering, + the file on disk may not reflect the data written until .flush() + or .close() is called. + """ + self._check_can_write() + if isinstance(data, (bytes, bytearray)): + length = len(data) + else: + # accept any data that supports the buffer protocol + data = memoryview(data) + length = data.nbytes + + compressed = self._compressor.compress(data) + self._fp.write(compressed) + self._pos += length + return length + + def flush(self, mode=FLUSH_BLOCK): + """Flush remaining data to the underlying stream. + + The mode argument can be ZstdFile.FLUSH_BLOCK, ZstdFile.FLUSH_FRAME. + Abuse of this method will reduce compression ratio, use it only when + necessary. + + If the program is interrupted afterwards, all data can be recovered. + To ensure saving to disk, also need to use os.fsync(fd). + + This method does nothing in reading mode. + """ + if self._mode == _MODE_READ: + return + self._check_not_closed() + if mode not in (self.FLUSH_BLOCK, self.FLUSH_FRAME): + raise ValueError("mode argument wrong value, it should be " + "ZstdCompressor.FLUSH_FRAME or " + "ZstdCompressor.FLUSH_BLOCK.") + if self._compressor.last_mode == mode: + return + # Flush zstd block/frame, and write. + data = self._compressor.flush(mode) + self._fp.write(data) + if hasattr(self._fp, "flush"): + self._fp.flush() + + def read(self, size=-1): + """Read up to size uncompressed bytes from the file. + + If size is negative or omitted, read until EOF is reached. + Returns b"" if the file is already at EOF. + """ + if size is None: + size = -1 + self._check_can_read() + return self._buffer.read(size) + + def read1(self, size=-1): + """Read up to size uncompressed bytes, while trying to avoid + making multiple reads from the underlying stream. Reads up to a + buffer's worth of data if size is negative. + + Returns b"" if the file is at EOF. + """ + self._check_can_read() + if size < 0: + # Note this should *not* be io.DEFAULT_BUFFER_SIZE. + # ZSTD_DStreamOutSize is the minimum amount to read guaranteeing + # a full block is read. + size = _ZSTD_DStreamOutSize + return self._buffer.read1(size) + + def readinto(self, b): + """Read bytes into b. + + Returns the number of bytes read (0 for EOF). + """ + self._check_can_read() + return self._buffer.readinto(b) + + def readinto1(self, b): + """Read bytes into b, while trying to avoid making multiple reads + from the underlying stream. + + Returns the number of bytes read (0 for EOF). + """ + self._check_can_read() + return self._buffer.readinto1(b) + + def readline(self, size=-1): + """Read a line of uncompressed bytes from the file. + + The terminating newline (if present) is retained. If size is + non-negative, no more than size bytes will be read (in which + case the line may be incomplete). Returns b'' if already at EOF. + """ + self._check_can_read() + return self._buffer.readline(size) + + def seek(self, offset, whence=io.SEEK_SET): + """Change the file position. + + The new position is specified by offset, relative to the + position indicated by whence. Possible values for whence are: + + 0: start of stream (default): offset must not be negative + 1: current stream position + 2: end of stream; offset must not be positive + + Returns the new file position. + + Note that seeking is emulated, so depending on the arguments, + this operation may be extremely slow. 
+ """ + self._check_can_read() + + # BufferedReader.seek() checks seekable + return self._buffer.seek(offset, whence) + + def peek(self, size=-1): + """Return buffered data without advancing the file position. + + Always returns at least one byte of data, unless at EOF. + The exact number of bytes returned is unspecified. + """ + # Relies on the undocumented fact that BufferedReader.peek() always + # returns at least one byte (except at EOF) + self._check_can_read() + return self._buffer.peek(size) + + def __next__(self): + ret = self._buffer.readline() + if ret: + return ret + raise StopIteration + + def tell(self): + """Return the current file position.""" + self._check_not_closed() + if self._mode == _MODE_READ: + return self._buffer.tell() + elif self._mode == _MODE_WRITE: + return self._pos + + def fileno(self): + """Return the file descriptor for the underlying file.""" + self._check_not_closed() + return self._fp.fileno() + + @property + def name(self): + self._check_not_closed() + return self._fp.name + + @property + def mode(self): + return 'wb' if self._mode == _MODE_WRITE else 'rb' + + @property + def closed(self): + """True if this file is closed.""" + return self._mode == _MODE_CLOSED + + def seekable(self): + """Return whether the file supports seeking.""" + return self.readable() and self._buffer.seekable() + + def readable(self): + """Return whether the file was opened for reading.""" + self._check_not_closed() + return self._mode == _MODE_READ + + def writable(self): + """Return whether the file was opened for writing.""" + self._check_not_closed() + return self._mode == _MODE_WRITE + + +# Copied from lzma module +def open( + filename, + mode="rb", + *, + level=None, + options=None, + zstd_dict=None, + encoding=None, + errors=None, + newline=None, +): + """Open a zstd compressed file in binary or text mode. + + filename can be either an actual file name (given as a str, bytes, or + PathLike object), in which case the named file is opened, or it can be an + existing file object to read from or write to. + + The mode parameter can be "r", "rb" (default), "w", "wb", "x", "xb", "a", + "ab" for binary mode, or "rt", "wt", "xt", "at" for text mode. + + The level, options, and zstd_dict parameters specify the settings the same + as ZstdFile. + + When using read mode (decompression), the options parameter is a dict + representing advanced decompression options. The level parameter is not + supported in this case. When using write mode (compression), only one of + level, an int representing the compression level, or options, a dict + representing advanced compression options, may be passed. In both modes, + zstd_dict is a ZstdDict instance containing a trained Zstandard dictionary. + + For binary mode, this function is equivalent to the ZstdFile constructor: + ZstdFile(filename, mode, ...). In this case, the encoding, errors and + newline parameters must not be provided. + + For text mode, an ZstdFile object is created, and wrapped in an + io.TextIOWrapper instance with the specified encoding, error handling + behavior, and line ending(s). 
+ """ + + if "t" in mode: + if "b" in mode: + raise ValueError("Invalid mode: %r" % (mode,)) + else: + if encoding is not None: + raise ValueError("Argument 'encoding' not supported in binary mode") + if errors is not None: + raise ValueError("Argument 'errors' not supported in binary mode") + if newline is not None: + raise ValueError("Argument 'newline' not supported in binary mode") + + zstd_mode = mode.replace("t", "") + binary_file = ZstdFile( + filename, zstd_mode, level=level, options=options, zstd_dict=zstd_dict + ) + + if "t" in mode: + return io.TextIOWrapper(binary_file, encoding, errors, newline) + else: + return binary_file diff --git a/Lib/shutil.py b/Lib/shutil.py index 510ae8c6f22d59..ca0a2ea2f7fa8a 100644 --- a/Lib/shutil.py +++ b/Lib/shutil.py @@ -32,6 +32,13 @@ except ImportError: _LZMA_SUPPORTED = False +try: + from compression import zstd + del zstd + _ZSTD_SUPPORTED = True +except ImportError: + _ZSTD_SUPPORTED = False + _WINDOWS = os.name == 'nt' posix = nt = None if os.name == 'posix': @@ -1006,6 +1013,8 @@ def _make_tarball(base_name, base_dir, compress="gzip", verbose=0, dry_run=0, tar_compression = 'bz2' elif _LZMA_SUPPORTED and compress == 'xz': tar_compression = 'xz' + elif _ZSTD_SUPPORTED and compress == 'zst': + tar_compression = 'zst' else: raise ValueError("bad value for 'compress', or compression format not " "supported : {0}".format(compress)) @@ -1134,6 +1143,10 @@ def _make_zipfile(base_name, base_dir, verbose=0, dry_run=0, _ARCHIVE_FORMATS['xztar'] = (_make_tarball, [('compress', 'xz')], "xz'ed tar-file") +if _ZSTD_SUPPORTED: + _ARCHIVE_FORMATS['zstdtar'] = (_make_tarball, [('compress', 'zst')], + "zstd'ed tar-file") + def get_archive_formats(): """Returns a list of supported formats for archiving and unarchiving. @@ -1174,7 +1187,7 @@ def make_archive(base_name, format, root_dir=None, base_dir=None, verbose=0, 'base_name' is the name of the file to create, minus any format-specific extension; 'format' is the archive format: one of "zip", "tar", "gztar", - "bztar", or "xztar". Or any other registered format. + "bztar", "zstdtar", or "xztar". Or any other registered format. 'root_dir' is a directory that will be the root directory of the archive; ie. 
we typically chdir into 'root_dir' before creating the @@ -1359,6 +1372,10 @@ def _unpack_tarfile(filename, extract_dir, *, filter=None): _UNPACK_FORMATS['xztar'] = (['.tar.xz', '.txz'], _unpack_tarfile, [], "xz'ed tar-file") +if _ZSTD_SUPPORTED: + _UNPACK_FORMATS['zstdtar'] = (['.tar.zst', '.tzst'], _unpack_tarfile, [], + "zstd'ed tar-file") + def _find_unpack_format(filename): for name, info in _UNPACK_FORMATS.items(): for extension in info[0]: diff --git a/Lib/tarfile.py b/Lib/tarfile.py index 28581f3e7a2692..7e86614be50896 100644 --- a/Lib/tarfile.py +++ b/Lib/tarfile.py @@ -399,7 +399,17 @@ def __init__(self, name, mode, comptype, fileobj, bufsize, self.exception = lzma.LZMAError else: self.cmp = lzma.LZMACompressor(preset=preset) - + elif comptype == "zst": + try: + from compression import zstd + except ImportError: + raise CompressionError("compression.zstd module is not available") from None + if mode == "r": + self.dbuf = b"" + self.cmp = zstd.ZstdDecompressor() + self.exception = zstd.ZstdError + else: + self.cmp = zstd.ZstdCompressor() elif comptype != "tar": raise CompressionError("unknown compression type %r" % comptype) @@ -591,6 +601,8 @@ def getcomptype(self): return "bz2" elif self.buf.startswith((b"\x5d\x00\x00\x80", b"\xfd7zXZ")): return "xz" + elif self.buf.startswith(b"\x28\xb5\x2f\xfd"): + return "zst" else: return "tar" @@ -1817,11 +1829,13 @@ def open(cls, name=None, mode="r", fileobj=None, bufsize=RECORDSIZE, **kwargs): 'r:gz' open for reading with gzip compression 'r:bz2' open for reading with bzip2 compression 'r:xz' open for reading with lzma compression + 'r:zst' open for reading with zstd compression 'a' or 'a:' open for appending, creating the file if necessary 'w' or 'w:' open for writing without compression 'w:gz' open for writing with gzip compression 'w:bz2' open for writing with bzip2 compression 'w:xz' open for writing with lzma compression + 'w:zst' open for writing with zstd compression 'x' or 'x:' create a tarfile exclusively without compression, raise an exception if the file is already created @@ -1831,16 +1845,20 @@ def open(cls, name=None, mode="r", fileobj=None, bufsize=RECORDSIZE, **kwargs): if the file is already created 'x:xz' create an lzma compressed tarfile, raise an exception if the file is already created + 'x:zst' create a zstd compressed tarfile, raise an exception + if the file is already created 'r|*' open a stream of tar blocks with transparent compression 'r|' open an uncompressed stream of tar blocks for reading 'r|gz' open a gzip compressed stream of tar blocks 'r|bz2' open a bzip2 compressed stream of tar blocks 'r|xz' open an lzma compressed stream of tar blocks + 'r|zst' open a zstd compressed stream of tar blocks 'w|' open an uncompressed stream for writing 'w|gz' open a gzip compressed stream for writing 'w|bz2' open a bzip2 compressed stream for writing 'w|xz' open an lzma compressed stream for writing + 'w|zst' open a zstd compressed stream for writing """ if not name and not fileobj: @@ -2006,12 +2024,48 @@ def xzopen(cls, name, mode="r", fileobj=None, preset=None, **kwargs): t._extfileobj = False return t + @classmethod + def zstopen(cls, name, mode="r", fileobj=None, level=None, options=None, + zstd_dict=None, **kwargs): + """Open zstd compressed tar archive name for reading or writing. + Appending is not allowed. 
+ """ + if mode not in ("r", "w", "x"): + raise ValueError("mode must be 'r', 'w' or 'x'") + + try: + from compression.zstd import ZstdFile, ZstdError + except ImportError: + raise CompressionError("compression.zstd module is not available") from None + + fileobj = ZstdFile( + fileobj or name, + mode, + level=level, + options=options, + zstd_dict=zstd_dict + ) + + try: + t = cls.taropen(name, mode, fileobj, **kwargs) + except (ZstdError, EOFError) as e: + fileobj.close() + if mode == 'r': + raise ReadError("not a zstd file") from e + raise + except: + fileobj.close() + raise + t._extfileobj = False + return t + # All *open() methods are registered here. OPEN_METH = { "tar": "taropen", # uncompressed tar "gz": "gzopen", # gzip compressed tar "bz2": "bz2open", # bzip2 compressed tar - "xz": "xzopen" # lzma compressed tar + "xz": "xzopen", # lzma compressed tar + "zst": "zstopen" # zstd compressed tar } #-------------------------------------------------------------------------- @@ -2963,6 +3017,9 @@ def main(): '.tbz': 'bz2', '.tbz2': 'bz2', '.tb2': 'bz2', + # zstd + '.zst': 'zst', + '.tzst': 'zst', } tar_mode = 'w:' + compressions[ext] if ext in compressions else 'w' tar_files = args.create diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py index 041f1250003b68..c74c3a3190947b 100644 --- a/Lib/test/support/__init__.py +++ b/Lib/test/support/__init__.py @@ -33,7 +33,7 @@ "is_resource_enabled", "requires", "requires_freebsd_version", "requires_gil_enabled", "requires_linux_version", "requires_mac_ver", "check_syntax_error", - "requires_gzip", "requires_bz2", "requires_lzma", + "requires_gzip", "requires_bz2", "requires_lzma", "requires_zstd", "bigmemtest", "bigaddrspacetest", "cpython_only", "get_attribute", "requires_IEEE_754", "requires_zlib", "has_fork_support", "requires_fork", @@ -527,6 +527,13 @@ def requires_lzma(reason='requires lzma'): lzma = None return unittest.skipUnless(lzma, reason) +def requires_zstd(reason='requires zstd'): + try: + from compression import zstd + except ImportError: + zstd = None + return unittest.skipUnless(zstd, reason) + def has_no_debug_ranges(): try: import _testcapi diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py index ed01163074a507..87991fbda4c7df 100644 --- a/Lib/test/test_shutil.py +++ b/Lib/test/test_shutil.py @@ -2153,6 +2153,10 @@ def test_unpack_archive_gztar(self): def test_unpack_archive_bztar(self): self.check_unpack_tarball('bztar') + @support.requires_zstd() + def test_unpack_archive_zstdtar(self): + self.check_unpack_tarball('zstdtar') + @support.requires_lzma() @unittest.skipIf(AIX and not _maxdataOK(), "AIX MAXDATA must be 0x20000000 or larger") def test_unpack_archive_xztar(self): diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py index fcbaf854cc294f..2d9649237a9382 100644 --- a/Lib/test/test_tarfile.py +++ b/Lib/test/test_tarfile.py @@ -38,6 +38,10 @@ import lzma except ImportError: lzma = None +try: + from compression import zstd +except ImportError: + zstd = None def sha256sum(data): return sha256(data).hexdigest() @@ -48,6 +52,7 @@ def sha256sum(data): gzipname = os.path.join(TEMPDIR, "testtar.tar.gz") bz2name = os.path.join(TEMPDIR, "testtar.tar.bz2") xzname = os.path.join(TEMPDIR, "testtar.tar.xz") +zstname = os.path.join(TEMPDIR, "testtar.tar.zst") tmpname = os.path.join(TEMPDIR, "tmp.tar") dotlessname = os.path.join(TEMPDIR, "testtar") @@ -90,6 +95,12 @@ class LzmaTest: open = lzma.LZMAFile if lzma else None taropen = tarfile.TarFile.xzopen +@support.requires_zstd() +class 
ZstdTest: + tarname = zstname + suffix = 'zst' + open = zstd.ZstdFile if zstd else None + taropen = tarfile.TarFile.zstopen class ReadTest(TarTest): @@ -271,6 +282,8 @@ class Bz2UstarReadTest(Bz2Test, UstarReadTest): class LzmaUstarReadTest(LzmaTest, UstarReadTest): pass +class ZstdUstarReadTest(ZstdTest, UstarReadTest): + pass class ListTest(ReadTest, unittest.TestCase): @@ -375,6 +388,8 @@ class Bz2ListTest(Bz2Test, ListTest): class LzmaListTest(LzmaTest, ListTest): pass +class ZstdListTest(ZstdTest, ListTest): + pass class CommonReadTest(ReadTest): @@ -837,6 +852,8 @@ class Bz2MiscReadTest(Bz2Test, MiscReadTestBase, unittest.TestCase): class LzmaMiscReadTest(LzmaTest, MiscReadTestBase, unittest.TestCase): pass +class ZstdMiscReadTest(ZstdTest, MiscReadTestBase, unittest.TestCase): + pass class StreamReadTest(CommonReadTest, unittest.TestCase): @@ -909,6 +926,9 @@ class Bz2StreamReadTest(Bz2Test, StreamReadTest): class LzmaStreamReadTest(LzmaTest, StreamReadTest): pass +class ZstdStreamReadTest(ZstdTest, StreamReadTest): + pass + class TarStreamModeReadTest(StreamModeTest, unittest.TestCase): def test_stream_mode_no_cache(self): @@ -925,6 +945,9 @@ class Bz2StreamModeReadTest(Bz2Test, TarStreamModeReadTest): class LzmaStreamModeReadTest(LzmaTest, TarStreamModeReadTest): pass +class ZstdStreamModeReadTest(ZstdTest, TarStreamModeReadTest): + pass + class DetectReadTest(TarTest, unittest.TestCase): def _testfunc_file(self, name, mode): try: @@ -986,6 +1009,8 @@ def test_detect_stream_bz2(self): class LzmaDetectReadTest(LzmaTest, DetectReadTest): pass +class ZstdDetectReadTest(ZstdTest, DetectReadTest): + pass class GzipBrokenHeaderCorrectException(GzipTest, unittest.TestCase): """ @@ -1666,6 +1691,8 @@ class Bz2WriteTest(Bz2Test, WriteTest): class LzmaWriteTest(LzmaTest, WriteTest): pass +class ZstdWriteTest(ZstdTest, WriteTest): + pass class StreamWriteTest(WriteTestBase, unittest.TestCase): @@ -1727,6 +1754,9 @@ class Bz2StreamWriteTest(Bz2Test, StreamWriteTest): class LzmaStreamWriteTest(LzmaTest, StreamWriteTest): decompressor = lzma.LZMADecompressor if lzma else None +class ZstdStreamWriteTest(ZstdTest, StreamWriteTest): + decompressor = zstd.ZstdDecompressor if zstd else None + class _CompressedWriteTest(TarTest): # This is not actually a standalone test. # It does not inherit WriteTest because it only makes sense with gz,bz2 @@ -2042,6 +2072,14 @@ def test_create_with_preset(self): tobj.add(self.file_path) +class ZstdCreateTest(ZstdTest, CreateTest): + + # Unlike gz and bz2, zstd uses the level keyword instead of compresslevel. + # It does not allow for level to be specified when reading. + def test_create_with_level(self): + with tarfile.open(tmpname, self.mode, level=1) as tobj: + tobj.add(self.file_path) + class CreateWithXModeTest(CreateTest): prefix = "x" @@ -2523,6 +2561,8 @@ class Bz2AppendTest(Bz2Test, AppendTestBase, unittest.TestCase): class LzmaAppendTest(LzmaTest, AppendTestBase, unittest.TestCase): pass +class ZstdAppendTest(ZstdTest, AppendTestBase, unittest.TestCase): + pass class LimitsTest(unittest.TestCase): @@ -2835,7 +2875,7 @@ def test_create_command_compressed(self): support.findfile('tokenize_tests-no-coding-cookie-' 'and-utf8-bom-sig-only.txt', subdir='tokenizedata')] - for filetype in (GzipTest, Bz2Test, LzmaTest): + for filetype in (GzipTest, Bz2Test, LzmaTest, ZstdTest): if not filetype.open: continue try: @@ -4257,7 +4297,7 @@ def setUpModule(): data = fobj.read() # Create compressed tarfiles. 
- for c in GzipTest, Bz2Test, LzmaTest: + for c in GzipTest, Bz2Test, LzmaTest, ZstdTest: if c.open: os_helper.unlink(c.tarname) testtarnames.append(c.tarname) diff --git a/Lib/test/test_zipfile/test_core.py b/Lib/test/test_zipfile/test_core.py index 4c9d9f4b56235d..ae898150658565 100644 --- a/Lib/test/test_zipfile/test_core.py +++ b/Lib/test/test_zipfile/test_core.py @@ -23,7 +23,7 @@ from test.support import script_helper, os_helper from test.support import ( findfile, requires_zlib, requires_bz2, requires_lzma, - captured_stdout, captured_stderr, requires_subprocess, + requires_zstd, captured_stdout, captured_stderr, requires_subprocess, cpython_only ) from test.support.os_helper import ( @@ -702,6 +702,10 @@ class LzmaTestsWithSourceFile(AbstractTestsWithSourceFile, unittest.TestCase): compression = zipfile.ZIP_LZMA +@requires_zstd() +class ZstdTestsWithSourceFile(AbstractTestsWithSourceFile, + unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD class AbstractTestZip64InSmallFiles: # These tests test the ZIP64 functionality without using large files, @@ -1279,6 +1283,10 @@ class LzmaTestZip64InSmallFiles(AbstractTestZip64InSmallFiles, unittest.TestCase): compression = zipfile.ZIP_LZMA +@requires_zstd() +class ZstdTestZip64InSmallFiles(AbstractTestZip64InSmallFiles, + unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD class AbstractWriterTests: @@ -1348,6 +1356,9 @@ class Bzip2WriterTests(AbstractWriterTests, unittest.TestCase): class LzmaWriterTests(AbstractWriterTests, unittest.TestCase): compression = zipfile.ZIP_LZMA +@requires_zstd() +class ZstdWriterTests(AbstractWriterTests, unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD class PyZipFileTests(unittest.TestCase): def assertCompiledIn(self, name, namelist): @@ -2678,6 +2689,17 @@ class LzmaBadCrcTests(AbstractBadCrcTests, unittest.TestCase): b'ePK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00' b'\x00>\x00\x00\x00\x00\x00') +@requires_zstd() +class ZstdBadCrcTests(AbstractBadCrcTests, unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD + zip_with_bad_crc = ( + b'PK\x03\x04?\x00\x00\x00]\x00\x00\x00!\x00V\xb1\x17J\x14\x00' + b'\x00\x00\x0b\x00\x00\x00\x05\x00\x00\x00afile(\xb5/\xfd\x00' + b'XY\x00\x00Hello WorldPK\x01\x02?\x03?\x00\x00\x00]\x00\x00\x00' + b'!\x00V\xb0\x17J\x14\x00\x00\x00\x0b\x00\x00\x00\x05\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00\x00afilePK' + b'\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00\x007\x00\x00\x00' + b'\x00\x00') class DecryptionTests(unittest.TestCase): """Check that ZIP decryption works. Since the library does not @@ -2905,6 +2927,10 @@ class LzmaTestsWithRandomBinaryFiles(AbstractTestsWithRandomBinaryFiles, unittest.TestCase): compression = zipfile.ZIP_LZMA +@requires_zstd() +class ZstdTestsWithRandomBinaryFiles(AbstractTestsWithRandomBinaryFiles, + unittest.TestCase): + compression = zipfile.ZIP_ZSTANDARD # Provide the tell() method but not seek() class Tellable: diff --git a/Lib/test/test_zstd/__init__.py b/Lib/test/test_zstd/__init__.py new file mode 100644 index 00000000000000..4b16ecc31156a5 --- /dev/null +++ b/Lib/test/test_zstd/__init__.py @@ -0,0 +1,5 @@ +import os +from test.support import load_package_tests + +def load_tests(*args): + return load_package_tests(os.path.dirname(__file__), *args) diff --git a/Lib/test/test_zstd/__main__.py b/Lib/test/test_zstd/__main__.py new file mode 100644 index 00000000000000..e25ac946edffe4 --- /dev/null +++ b/Lib/test/test_zstd/__main__.py @@ -0,0 +1,7 @@ +import unittest + +from . 
import load_tests # noqa: F401 + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd/test_core.py new file mode 100644 index 00000000000000..07463c519eb83a --- /dev/null +++ b/Lib/test/test_zstd/test_core.py @@ -0,0 +1,2693 @@ +import array +import gc +import io +import pathlib +import pickle +import random +import builtins +import re +import os +import unittest +import tempfile +import threading + +from compression._common import _streams + +from test.support.import_helper import import_module +from test.support import threading_helper +from test.support import _1M +from test.support import Py_GIL_DISABLED + +zstd = import_module("compression.zstd") +_zstd = import_module("_zstd") +from compression.zstd import ( + zstdfile, + compress, + decompress, + ZstdCompressor, + ZstdDecompressor, + ZstdDict, + ZstdError, + zstd_version, + zstd_version_info, + compressionLevel_values, + get_frame_info, + get_frame_size, + finalize_dict, + train_dict, + CParameter, + DParameter, + Strategy, + ZstdFile, + zstd_support_multithread, +) +from compression.zstd.zstdfile import open + +_1K = 1024 +_130_1K = 130 * _1K +DICT_SIZE1 = 3*_1K + +DAT_130K_D = None +DAT_130K_C = None + +DECOMPRESSED_DAT = None +COMPRESSED_DAT = None + +DECOMPRESSED_100_PLUS_32KB = None +COMPRESSED_100_PLUS_32KB = None + +SKIPPABLE_FRAME = None + +THIS_FILE_BYTES = None +THIS_FILE_STR = None +COMPRESSED_THIS_FILE = None + +COMPRESSED_BOGUS = None + +SAMPLES = None + +TRAINED_DICT = None + +KB = 1024 +MB = 1024*1024 + +def setUpModule(): + # uncompressed size 130KB, more than a zstd block. + # with a frame epilogue, 4 bytes checksum. + global DAT_130K_D + DAT_130K_D = bytes([random.randint(0, 127) for _ in range(130*1024)]) + + global DAT_130K_C + DAT_130K_C = compress(DAT_130K_D, options={CParameter.checksumFlag:1}) + + global DECOMPRESSED_DAT + DECOMPRESSED_DAT = b'abcdefg123456' * 1000 + + global COMPRESSED_DAT + COMPRESSED_DAT = compress(DECOMPRESSED_DAT) + + global DECOMPRESSED_100_PLUS_32KB + DECOMPRESSED_100_PLUS_32KB = b'a' * (100 + 32*1024) + + global COMPRESSED_100_PLUS_32KB + COMPRESSED_100_PLUS_32KB = compress(DECOMPRESSED_100_PLUS_32KB) + + global SKIPPABLE_FRAME + SKIPPABLE_FRAME = (0x184D2A50).to_bytes(4, byteorder='little') + \ + (32*1024).to_bytes(4, byteorder='little') + \ + b'a' * (32*1024) + + global THIS_FILE_BYTES, THIS_FILE_STR + with builtins.open(os.path.abspath(__file__), 'rb') as f: + THIS_FILE_BYTES = f.read() + THIS_FILE_BYTES = re.sub(rb'\r?\n', rb'\n', THIS_FILE_BYTES) + THIS_FILE_STR = THIS_FILE_BYTES.decode('utf-8') + + global COMPRESSED_THIS_FILE + COMPRESSED_THIS_FILE = compress(THIS_FILE_BYTES) + + global COMPRESSED_BOGUS + COMPRESSED_BOGUS = DECOMPRESSED_DAT + + # dict data + words = [b'red', b'green', b'yellow', b'black', b'withe', b'blue', + b'lilac', b'purple', b'navy', b'glod', b'silver', b'olive', + b'dog', b'cat', b'tiger', b'lion', b'fish', b'bird'] + lst = [] + for i in range(300): + sample = [b'%s = %d' % (random.choice(words), random.randrange(100)) + for j in range(20)] + sample = b'\n'.join(sample) + + lst.append(sample) + global SAMPLES + SAMPLES = lst + assert len(SAMPLES) > 10 + + global TRAINED_DICT + TRAINED_DICT = train_dict(SAMPLES, 3*1024) + assert len(TRAINED_DICT.dict_content) <= 3*1024 + + +class FunctionsTestCase(unittest.TestCase): + + def test_version(self): + s = ".".join((str(i) for i in zstd_version_info)) + self.assertEqual(s, zstd_version) + + def test_compressionLevel_values(self): + 
self.assertIs(type(compressionLevel_values.default), int) + self.assertIs(type(compressionLevel_values.min), int) + self.assertIs(type(compressionLevel_values.max), int) + self.assertLess(compressionLevel_values.min, compressionLevel_values.max) + + def test_roundtrip_default(self): + raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6] + dat1 = compress(raw_dat) + dat2 = decompress(dat1) + self.assertEqual(dat2, raw_dat) + + def test_roundtrip_level(self): + raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6] + _default, minv, maxv = compressionLevel_values + + for level in range(max(-20, minv), maxv + 1): + dat1 = compress(raw_dat, level) + dat2 = decompress(dat1) + self.assertEqual(dat2, raw_dat) + + def test_get_frame_info(self): + # no dict + info = get_frame_info(COMPRESSED_100_PLUS_32KB[:20]) + self.assertEqual(info.decompressed_size, 32 * 1024 + 100) + self.assertEqual(info.dictionary_id, 0) + + # use dict + dat = compress(b"a" * 345, zstd_dict=TRAINED_DICT) + info = get_frame_info(dat) + self.assertEqual(info.decompressed_size, 345) + self.assertEqual(info.dictionary_id, TRAINED_DICT.dict_id) + + with self.assertRaisesRegex(ZstdError, "not less than the frame header"): + get_frame_info(b"aaaaaaaaaaaaaa") + + def test_get_frame_size(self): + size = get_frame_size(COMPRESSED_100_PLUS_32KB) + self.assertEqual(size, len(COMPRESSED_100_PLUS_32KB)) + + with self.assertRaisesRegex(ZstdError, "not less than this complete frame"): + get_frame_size(b"aaaaaaaaaaaaaa") + + def test_decompress_2x130_1K(self): + decompressed_size = get_frame_info(DAT_130K_C).decompressed_size + self.assertEqual(decompressed_size, _130_1K) + + dat = decompress(DAT_130K_C + DAT_130K_C) + self.assertEqual(len(dat), 2 * _130_1K) + +class ClassShapeTestCase(unittest.TestCase): + + def test_ZstdCompressor(self): + # class attributes + ZstdCompressor.CONTINUE + ZstdCompressor.FLUSH_BLOCK + ZstdCompressor.FLUSH_FRAME + + # method & me_1Mer + ZstdCompressor() + ZstdCompressor(12, zstd_dict=TRAINED_DICT) + c = ZstdCompressor(level=2, zstd_dict=TRAINED_DICT) + + c.compress(b"123456") + c.compress(b"123456", ZstdCompressor.CONTINUE) + c.compress(data=b"123456", mode=c.CONTINUE) + + c.flush() + c.flush(ZstdCompressor.FLUSH_BLOCK) + c.flush(mode=c.FLUSH_FRAME) + + c.last_mode + + # decompressor method & me_1Mer + with self.assertRaises(AttributeError): + c.decompress(b"") + with self.assertRaises(AttributeError): + c.at_frame_edge + with self.assertRaises(AttributeError): + c.eof + with self.assertRaises(AttributeError): + c.needs_input + + # read only attribute + with self.assertRaises(AttributeError): + c.last_mode = ZstdCompressor.FLUSH_BLOCK + + # name + self.assertIn(".ZstdCompressor", str(type(c))) + + # doesn't support pickle + with self.assertRaises(TypeError): + pickle.dumps(c) + + # supports subclass + class SubClass(ZstdCompressor): + pass + + def test_Decompressor(self): + # method & me_1Mer + ZstdDecompressor() + ZstdDecompressor(TRAINED_DICT, {}) + d = ZstdDecompressor(zstd_dict=TRAINED_DICT, options={}) + + d.decompress(b"") + d.decompress(b"", 100) + d.decompress(data=b"", max_length=100) + + d.eof + d.needs_input + d.unused_data + + # ZstdCompressor attributes + with self.assertRaises(AttributeError): + d.CONTINUE + with self.assertRaises(AttributeError): + d.FLUSH_BLOCK + with self.assertRaises(AttributeError): + d.FLUSH_FRAME + with self.assertRaises(AttributeError): + d.compress(b"") + with self.assertRaises(AttributeError): + d.flush() + + # read only attributes + with 
self.assertRaises(AttributeError): + d.eof = True + with self.assertRaises(AttributeError): + d.needs_input = True + with self.assertRaises(AttributeError): + d.unused_data = b"" + + # name + self.assertIn(".ZstdDecompressor", str(type(d))) + + # doesn't support pickle + with self.assertRaises(TypeError): + pickle.dumps(d) + + # supports subclass + class SubClass(ZstdDecompressor): + pass + + def test_ZstdDict(self): + ZstdDict(b"12345678", True) + zd = ZstdDict(b"12345678", is_raw=True) + + self.assertEqual(type(zd.dict_content), bytes) + self.assertEqual(zd.dict_id, 0) + self.assertEqual(zd.as_digested_dict[1], 0) + self.assertEqual(zd.as_undigested_dict[1], 1) + self.assertEqual(zd.as_prefix[1], 2) + + # name + self.assertIn(".ZstdDict", str(type(zd))) + + # doesn't support pickle + with self.assertRaisesRegex(TypeError, r"cannot pickle"): + pickle.dumps(zd) + with self.assertRaisesRegex(TypeError, r"cannot pickle"): + pickle.dumps(zd.as_prefix) + + # supports subclass + class SubClass(ZstdDict): + pass + + def test_Strategy(self): + # class attributes + Strategy.fast + Strategy.dfast + Strategy.greedy + Strategy.lazy + Strategy.lazy2 + Strategy.btlazy2 + Strategy.btopt + Strategy.btultra + Strategy.btultra2 + + def test_CParameter(self): + CParameter.compressionLevel + CParameter.windowLog + CParameter.hashLog + CParameter.chainLog + CParameter.searchLog + CParameter.minMatch + CParameter.targetLength + CParameter.strategy + with self.assertRaises(NotImplementedError): + CParameter.targetCBlockSize + + CParameter.enableLongDistanceMatching + CParameter.ldmHashLog + CParameter.ldmMinMatch + CParameter.ldmBucketSizeLog + CParameter.ldmHashRateLog + + CParameter.contentSizeFlag + CParameter.checksumFlag + CParameter.dictIDFlag + + CParameter.nbWorkers + CParameter.jobSize + CParameter.overlapLog + + t = CParameter.windowLog.bounds() + self.assertEqual(len(t), 2) + self.assertEqual(type(t[0]), int) + self.assertEqual(type(t[1]), int) + + def test_DParameter(self): + DParameter.windowLogMax + + t = DParameter.windowLogMax.bounds() + self.assertEqual(len(t), 2) + self.assertEqual(type(t[0]), int) + self.assertEqual(type(t[1]), int) + + def test_zstderror_pickle(self): + try: + decompress(b"invalid data") + except Exception as e: + s = pickle.dumps(e) + obj = pickle.loads(s) + self.assertEqual(type(obj), ZstdError) + else: + self.assertFalse(True, "unreachable code path") + + def test_ZstdFile_extend(self): + # These classes and variables can be used to extend ZstdFile, + # so pin them down. 
+ self.assertTrue(issubclass(ZstdFile, io.BufferedIOBase)) + self.assertIs(ZstdFile._READER_CLASS, _streams.DecompressReader) + + # mode + self.assertEqual(zstdfile._MODE_CLOSED, 0) + self.assertEqual(zstdfile._MODE_READ, 1) + self.assertEqual(zstdfile._MODE_WRITE, 2) + + +class CompressorTestCase(unittest.TestCase): + + def test_simple_compress_bad_args(self): + # ZstdCompressor + self.assertRaises(TypeError, ZstdCompressor, []) + self.assertRaises(TypeError, ZstdCompressor, level=3.14) + self.assertRaises(TypeError, ZstdCompressor, level="abc") + self.assertRaises(TypeError, ZstdCompressor, options=b"abc") + + self.assertRaises(TypeError, ZstdCompressor, zstd_dict=123) + self.assertRaises(TypeError, ZstdCompressor, zstd_dict=b"abcd1234") + self.assertRaises(TypeError, ZstdCompressor, zstd_dict={1: 2, 3: 4}) + + with self.assertRaises(ValueError): + ZstdCompressor(2**31) + with self.assertRaises(ValueError): + ZstdCompressor(options={2**31: 100}) + + with self.assertRaises(ZstdError): + ZstdCompressor(options={CParameter.windowLog: 100}) + with self.assertRaises(ZstdError): + ZstdCompressor(options={3333: 100}) + + # Method bad arguments + zc = ZstdCompressor() + self.assertRaises(TypeError, zc.compress) + self.assertRaises((TypeError, ValueError), zc.compress, b"foo", b"bar") + self.assertRaises(TypeError, zc.compress, "str") + self.assertRaises((TypeError, ValueError), zc.flush, b"foo") + self.assertRaises(TypeError, zc.flush, b"blah", 1) + + self.assertRaises(ValueError, zc.compress, b'', -1) + self.assertRaises(ValueError, zc.compress, b'', 3) + self.assertRaises(ValueError, zc.flush, zc.CONTINUE) # 0 + self.assertRaises(ValueError, zc.flush, 3) + + zc.compress(b'') + zc.compress(b'', zc.CONTINUE) + zc.compress(b'', zc.FLUSH_BLOCK) + zc.compress(b'', zc.FLUSH_FRAME) + empty = zc.flush() + zc.flush(zc.FLUSH_BLOCK) + zc.flush(zc.FLUSH_FRAME) + + def test_compress_parameters(self): + d = {CParameter.compressionLevel : 10, + + CParameter.windowLog : 12, + CParameter.hashLog : 10, + CParameter.chainLog : 12, + CParameter.searchLog : 12, + CParameter.minMatch : 4, + CParameter.targetLength : 12, + CParameter.strategy : Strategy.lazy, + + CParameter.enableLongDistanceMatching : 1, + CParameter.ldmHashLog : 12, + CParameter.ldmMinMatch : 11, + CParameter.ldmBucketSizeLog : 5, + CParameter.ldmHashRateLog : 12, + + CParameter.contentSizeFlag : 1, + CParameter.checksumFlag : 1, + CParameter.dictIDFlag : 0, + + CParameter.nbWorkers : 2 if zstd_support_multithread else 0, + CParameter.jobSize : 5*_1M if zstd_support_multithread else 0, + CParameter.overlapLog : 9 if zstd_support_multithread else 0, + } + ZstdCompressor(options=d) + + # larger than signed int, ValueError + d1 = d.copy() + d1[CParameter.ldmBucketSizeLog] = 2**31 + self.assertRaises(ValueError, ZstdCompressor, d1) + + # clamp compressionLevel + compress(b'', compressionLevel_values.max+1) + compress(b'', compressionLevel_values.min-1) + + compress(b'', {CParameter.compressionLevel:compressionLevel_values.max+1}) + compress(b'', {CParameter.compressionLevel:compressionLevel_values.min-1}) + + # zstd lib doesn't support MT compression + if not zstd_support_multithread: + with self.assertRaises(ZstdError): + ZstdCompressor({CParameter.nbWorkers:4}) + with self.assertRaises(ZstdError): + ZstdCompressor({CParameter.jobSize:4}) + with self.assertRaises(ZstdError): + ZstdCompressor({CParameter.overlapLog:4}) + + # out of bounds error msg + option = {CParameter.windowLog:100} + with self.assertRaisesRegex(ZstdError, + (r'Error when setting 
zstd compression parameter "windowLog", ' + r'it should \d+ <= value <= \d+, provided value is 100\. ' + r'\(zstd v\d\.\d\.\d, (?:32|64)-bit build\)')): + compress(b'', option) + + def test_unknown_compression_parameter(self): + KEY = 100001234 + option = {CParameter.compressionLevel: 10, + KEY: 200000000} + pattern = r'Zstd compression parameter.*?"unknown parameter \(key %d\)"' \ + % KEY + with self.assertRaisesRegex(ZstdError, pattern): + ZstdCompressor(option) + + @unittest.skipIf(True,#not zstd_support_multithread, + "zstd build doesn't support multi-threaded compression") + def test_zstd_multithread_compress(self): + size = 40*_1M + b = THIS_FILE_BYTES * (size // len(THIS_FILE_BYTES)) + + options = {CParameter.compressionLevel : 4, + CParameter.nbWorkers : 2} + + # compress() + dat1 = compress(b, options=options) + dat2 = decompress(dat1) + self.assertEqual(dat2, b) + + # ZstdCompressor + c = ZstdCompressor(options=options) + dat1 = c.compress(b, c.CONTINUE) + dat2 = c.compress(b, c.FLUSH_BLOCK) + dat3 = c.compress(b, c.FLUSH_FRAME) + dat4 = decompress(dat1+dat2+dat3) + self.assertEqual(dat4, b * 3) + + # ZstdFile + with ZstdFile(io.BytesIO(), 'w', + options=options) as f: + f.write(b) + + def test_compress_flushblock(self): + point = len(THIS_FILE_BYTES) // 2 + + c = ZstdCompressor() + self.assertEqual(c.last_mode, c.FLUSH_FRAME) + dat1 = c.compress(THIS_FILE_BYTES[:point]) + self.assertEqual(c.last_mode, c.CONTINUE) + dat1 += c.compress(THIS_FILE_BYTES[point:], c.FLUSH_BLOCK) + self.assertEqual(c.last_mode, c.FLUSH_BLOCK) + dat2 = c.flush() + pattern = r"Compressed data ended before the end-of-stream marker" + with self.assertRaisesRegex(ZstdError, pattern): + decompress(dat1) + + dat3 = decompress(dat1 + dat2) + + self.assertEqual(dat3, THIS_FILE_BYTES) + + def test_compress_flushframe(self): + # test compress & decompress + point = len(THIS_FILE_BYTES) // 2 + + c = ZstdCompressor() + + dat1 = c.compress(THIS_FILE_BYTES[:point]) + self.assertEqual(c.last_mode, c.CONTINUE) + + dat1 += c.compress(THIS_FILE_BYTES[point:], c.FLUSH_FRAME) + self.assertEqual(c.last_mode, c.FLUSH_FRAME) + + nt = get_frame_info(dat1) + self.assertEqual(nt.decompressed_size, None) # no content size + + dat2 = decompress(dat1) + + self.assertEqual(dat2, THIS_FILE_BYTES) + + # single .FLUSH_FRAME mode has content size + c = ZstdCompressor() + dat = c.compress(THIS_FILE_BYTES, mode=c.FLUSH_FRAME) + self.assertEqual(c.last_mode, c.FLUSH_FRAME) + + nt = get_frame_info(dat) + self.assertEqual(nt.decompressed_size, len(THIS_FILE_BYTES)) + + def test_compress_empty(self): + # output empty content frame + self.assertNotEqual(compress(b''), b'') + + c = ZstdCompressor() + self.assertNotEqual(c.compress(b'', c.FLUSH_FRAME), b'') + +class DecompressorTestCase(unittest.TestCase): + + def test_simple_decompress_bad_args(self): + # ZstdDecompressor + self.assertRaises(TypeError, ZstdDecompressor, ()) + self.assertRaises(TypeError, ZstdDecompressor, zstd_dict=123) + self.assertRaises(TypeError, ZstdDecompressor, zstd_dict=b'abc') + self.assertRaises(TypeError, ZstdDecompressor, zstd_dict={1:2, 3:4}) + + self.assertRaises(TypeError, ZstdDecompressor, options=123) + self.assertRaises(TypeError, ZstdDecompressor, options='abc') + self.assertRaises(TypeError, ZstdDecompressor, options=b'abc') + + with self.assertRaises(ValueError): + ZstdDecompressor(options={2**31 : 100}) + + with self.assertRaises(ZstdError): + ZstdDecompressor(options={DParameter.windowLogMax:100}) + with self.assertRaises(ZstdError): + 
ZstdDecompressor(options={3333 : 100}) + + empty = compress(b'') + lzd = ZstdDecompressor() + self.assertRaises(TypeError, lzd.decompress) + self.assertRaises(TypeError, lzd.decompress, b"foo", b"bar") + self.assertRaises(TypeError, lzd.decompress, "str") + lzd.decompress(empty) + + def test_decompress_parameters(self): + d = {DParameter.windowLogMax : 15} + ZstdDecompressor(options=d) + + # larger than signed int, ValueError + d1 = d.copy() + d1[DParameter.windowLogMax] = 2**31 + self.assertRaises(ValueError, ZstdDecompressor, None, d1) + + # out of bounds error msg + options = {DParameter.windowLogMax:100} + with self.assertRaisesRegex(ZstdError, + (r'Error when setting zstd decompression parameter "windowLogMax", ' + r'it should \d+ <= value <= \d+, provided value is 100\. ' + r'\(zstd v\d\.\d\.\d, (?:32|64)-bit build\)')): + decompress(b'', options=options) + + def test_unknown_decompression_parameter(self): + KEY = 100001234 + options = {DParameter.windowLogMax: DParameter.windowLogMax.bounds()[1], + KEY: 200000000} + pattern = r'Zstd decompression parameter.*?"unknown parameter \(key %d\)"' \ + % KEY + with self.assertRaisesRegex(ZstdError, pattern): + ZstdDecompressor(options=options) + + def test_decompress_epilogue_flags(self): + # DAT_130K_C has a 4 bytes checksum at frame epilogue + + # full unlimited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + with self.assertRaises(EOFError): + dat = d.decompress(b'') + + # full limited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C, _130_1K) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + with self.assertRaises(EOFError): + dat = d.decompress(b'', 0) + + # [:-4] unlimited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-4]) + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.needs_input) + + dat = d.decompress(b'') + self.assertEqual(len(dat), 0) + self.assertTrue(d.needs_input) + + # [:-4] limited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-4], _130_1K) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + dat = d.decompress(b'', 0) + self.assertEqual(len(dat), 0) + self.assertFalse(d.needs_input) + + # [:-3] unlimited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-3]) + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.needs_input) + + dat = d.decompress(b'') + self.assertEqual(len(dat), 0) + self.assertTrue(d.needs_input) + + # [:-3] limited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-3], _130_1K) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + dat = d.decompress(b'', 0) + self.assertEqual(len(dat), 0) + self.assertFalse(d.needs_input) + + # [:-1] unlimited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-1]) + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.needs_input) + + dat = d.decompress(b'') + self.assertEqual(len(dat), 0) + self.assertTrue(d.needs_input) + + # [:-1] limited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-1], _130_1K) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + dat = d.decompress(b'', 0) + self.assertEqual(len(dat), 0) + self.assertFalse(d.needs_input) + + def test_decompressor_arg(self): + zd = ZstdDict(b'12345678', True) + + with self.assertRaises(TypeError): + d = ZstdDecompressor(zstd_dict={}) + + with self.assertRaises(TypeError): + d = ZstdDecompressor(options=zd) + + ZstdDecompressor() + ZstdDecompressor(zd, {}) 
+ ZstdDecompressor(zstd_dict=zd, options={DParameter.windowLogMax:25}) + + def test_decompressor_1(self): + # empty + d = ZstdDecompressor() + dat = d.decompress(b'') + + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + + # 130_1K full + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C) + + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + + # 130_1K full, limit output + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C, _130_1K) + + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + + # 130_1K, without 4 bytes checksum + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-4]) + + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + + # above, limit output + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-4], _130_1K) + + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + + # full, unused_data + TRAIL = b'89234893abcd' + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C + TRAIL, _130_1K) + + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, TRAIL) + + def test_decompressor_chunks_read_300(self): + TRAIL = b'89234893abcd' + DAT = DAT_130K_C + TRAIL + d = ZstdDecompressor() + + bi = io.BytesIO(DAT) + lst = [] + while True: + if d.needs_input: + dat = bi.read(300) + if not dat: + break + else: + raise Exception('should not get here') + + ret = d.decompress(dat) + lst.append(ret) + if d.eof: + break + + ret = b''.join(lst) + + self.assertEqual(len(ret), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data + bi.read(), TRAIL) + + def test_decompressor_chunks_read_3(self): + TRAIL = b'89234893' + DAT = DAT_130K_C + TRAIL + d = ZstdDecompressor() + + bi = io.BytesIO(DAT) + lst = [] + while True: + if d.needs_input: + dat = bi.read(3) + if not dat: + break + else: + dat = b'' + + ret = d.decompress(dat, 1) + lst.append(ret) + if d.eof: + break + + ret = b''.join(lst) + + self.assertEqual(len(ret), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data + bi.read(), TRAIL) + + + def test_decompress_empty(self): + with self.assertRaises(ZstdError): + decompress(b'') + + d = ZstdDecompressor() + self.assertEqual(d.decompress(b''), b'') + self.assertFalse(d.eof) + + def test_decompress_empty_content_frame(self): + DAT = compress(b'') + # decompress + self.assertGreaterEqual(len(DAT), 4) + self.assertEqual(decompress(DAT), b'') + + with self.assertRaises(ZstdError): + decompress(DAT[:-1]) + + # ZstdDecompressor + d = ZstdDecompressor() + dat = d.decompress(DAT) + self.assertEqual(dat, b'') + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + d = ZstdDecompressor() + dat = d.decompress(DAT[:-1]) + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + +class DecompressorFlagsTestCase(unittest.TestCase): + + @classmethod + def setUpClass(cls): + options = {CParameter.checksumFlag:1} + c = ZstdCompressor(options) + + cls.DECOMPRESSED_42 = b'a'*42 + cls.FRAME_42 = c.compress(cls.DECOMPRESSED_42, c.FLUSH_FRAME) + + cls.DECOMPRESSED_60 = b'a'*60 + cls.FRAME_60 = c.compress(cls.DECOMPRESSED_60, c.FLUSH_FRAME) + + 
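+        # concatenations of the two frames, used by the multi-frame tests below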
cls.FRAME_42_60 = cls.FRAME_42 + cls.FRAME_60 + cls.DECOMPRESSED_42_60 = cls.DECOMPRESSED_42 + cls.DECOMPRESSED_60 + + cls._130_1K = 130*1024 + + c = ZstdCompressor() + cls.UNKNOWN_FRAME_42 = c.compress(cls.DECOMPRESSED_42) + c.flush() + cls.UNKNOWN_FRAME_60 = c.compress(cls.DECOMPRESSED_60) + c.flush() + cls.UNKNOWN_FRAME_42_60 = cls.UNKNOWN_FRAME_42 + cls.UNKNOWN_FRAME_60 + + cls.TRAIL = b'12345678abcdefg!@#$%^&*()_+|' + + def test_function_decompress(self): + + self.assertEqual(len(decompress(COMPRESSED_100_PLUS_32KB)), 100+32*1024) + + # 1 frame + self.assertEqual(decompress(self.FRAME_42), self.DECOMPRESSED_42) + + self.assertEqual(decompress(self.UNKNOWN_FRAME_42), self.DECOMPRESSED_42) + + pattern = r"Compressed data ended before the end-of-stream marker" + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.FRAME_42[:1]) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.FRAME_42[:-4]) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.FRAME_42[:-1]) + + # 2 frames + self.assertEqual(decompress(self.FRAME_42_60), self.DECOMPRESSED_42_60) + + self.assertEqual(decompress(self.UNKNOWN_FRAME_42_60), self.DECOMPRESSED_42_60) + + self.assertEqual(decompress(self.FRAME_42 + self.UNKNOWN_FRAME_60), + self.DECOMPRESSED_42_60) + + self.assertEqual(decompress(self.UNKNOWN_FRAME_42 + self.FRAME_60), + self.DECOMPRESSED_42_60) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.FRAME_42_60[:-4]) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.UNKNOWN_FRAME_42_60[:-1]) + + # 130_1K + self.assertEqual(decompress(DAT_130K_C), DAT_130K_D) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(DAT_130K_C[:-4]) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(DAT_130K_C[:-1]) + + # Unknown frame descriptor + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(b'aaaaaaaaa') + + self.assertEqual( + decompress(self.FRAME_42 + b'aaaaaaaaa'), + self.DECOMPRESSED_42 + ) + + self.assertEqual( + decompress(self.UNKNOWN_FRAME_42_60 + b'aaaaaaaaa'), + self.DECOMPRESSED_42_60 + ) + + # doesn't match checksum + checksum = DAT_130K_C[-4:] + if checksum[0] == 255: + wrong_checksum = bytes([254]) + checksum[1:] + else: + wrong_checksum = bytes([checksum[0]+1]) + checksum[1:] + + dat = DAT_130K_C[:-4] + wrong_checksum + + with self.assertRaisesRegex(ZstdError, "doesn't match checksum"): + decompress(dat) + + def test_function_skippable(self): + self.assertEqual(decompress(SKIPPABLE_FRAME), b'') + self.assertEqual(decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME), b'') + + # 1 frame + 2 skippable + self.assertEqual(len(decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME + DAT_130K_C)), + self._130_1K) + + self.assertEqual(len(decompress(DAT_130K_C + SKIPPABLE_FRAME + SKIPPABLE_FRAME)), + self._130_1K) + + self.assertEqual(len(decompress(SKIPPABLE_FRAME + DAT_130K_C + SKIPPABLE_FRAME)), + self._130_1K) + + # unknown size + self.assertEqual(decompress(SKIPPABLE_FRAME + self.UNKNOWN_FRAME_60), + self.DECOMPRESSED_60) + + self.assertEqual(decompress(self.UNKNOWN_FRAME_60 + SKIPPABLE_FRAME), + self.DECOMPRESSED_60) + + # 2 frames + 1 skippable + self.assertEqual(decompress(self.FRAME_42 + SKIPPABLE_FRAME + self.FRAME_60), + self.DECOMPRESSED_42_60) + + self.assertEqual(decompress(SKIPPABLE_FRAME + self.FRAME_42_60), + self.DECOMPRESSED_42_60) + + self.assertEqual(decompress(self.UNKNOWN_FRAME_42_60 + SKIPPABLE_FRAME), + self.DECOMPRESSED_42_60) + + # incomplete + with 
self.assertRaises(ZstdError): + decompress(SKIPPABLE_FRAME[:1]) + + with self.assertRaises(ZstdError): + decompress(SKIPPABLE_FRAME[:-1]) + + with self.assertRaises(ZstdError): + decompress(self.FRAME_42 + SKIPPABLE_FRAME[:-1]) + + # Unknown frame descriptor + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(b'aaaaaaaaa' + SKIPPABLE_FRAME) + + self.assertEqual( + decompress(SKIPPABLE_FRAME + b'aaaaaaaaa'), + b'' + ) + + self.assertEqual( + decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME + b'aaaaaaaaa'), + b'' + ) + + def test_decompressor_1(self): + # empty 1 + d = ZstdDecompressor() + + dat = d.decompress(b'') + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(b'', 0) + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(COMPRESSED_100_PLUS_32KB + b'a') + self.assertEqual(dat, DECOMPRESSED_100_PLUS_32KB) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'a') + self.assertEqual(d.unused_data, b'a') # twice + + # empty 2 + d = ZstdDecompressor() + + dat = d.decompress(b'', 0) + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(b'') + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(COMPRESSED_100_PLUS_32KB + b'a') + self.assertEqual(dat, DECOMPRESSED_100_PLUS_32KB) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'a') + self.assertEqual(d.unused_data, b'a') # twice + + # 1 frame + d = ZstdDecompressor() + dat = d.decompress(self.FRAME_42) + + self.assertEqual(dat, self.DECOMPRESSED_42) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + with self.assertRaises(EOFError): + d.decompress(b'') + + # 1 frame, trail + d = ZstdDecompressor() + dat = d.decompress(self.FRAME_42 + self.TRAIL) + + self.assertEqual(dat, self.DECOMPRESSED_42) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, self.TRAIL) + self.assertEqual(d.unused_data, self.TRAIL) # twice + + # 1 frame, 32_1K + temp = compress(b'a'*(32*1024)) + d = ZstdDecompressor() + dat = d.decompress(temp, 32*1024) + + self.assertEqual(dat, b'a'*(32*1024)) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + with self.assertRaises(EOFError): + d.decompress(b'') + + # 1 frame, 32_1K+100, trail + d = ZstdDecompressor() + dat = d.decompress(COMPRESSED_100_PLUS_32KB+self.TRAIL, 100) # 100 bytes + + self.assertEqual(len(dat), 100) + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + + dat = d.decompress(b'') # 32_1K + + self.assertEqual(len(dat), 32*1024) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, self.TRAIL) + self.assertEqual(d.unused_data, self.TRAIL) # twice + + with self.assertRaises(EOFError): + d.decompress(b'') + + # incomplete 1 + 
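+        # only the first byte of the frame is fed in, so nothing can be decoded yet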
d = ZstdDecompressor() + dat = d.decompress(self.FRAME_60[:1]) + + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # incomplete 2 + d = ZstdDecompressor() + + dat = d.decompress(self.FRAME_60[:-4]) + self.assertEqual(dat, self.DECOMPRESSED_60) + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # incomplete 3 + d = ZstdDecompressor() + + dat = d.decompress(self.FRAME_60[:-1]) + self.assertEqual(dat, self.DECOMPRESSED_60) + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + + # incomplete 4 + d = ZstdDecompressor() + + dat = d.decompress(self.FRAME_60[:-4], 60) + self.assertEqual(dat, self.DECOMPRESSED_60) + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(b'') + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # Unknown frame descriptor + d = ZstdDecompressor() + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + d.decompress(b'aaaaaaaaa') + + def test_decompressor_skippable(self): + # 1 skippable + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME) + + self.assertEqual(dat, b'') + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # 1 skippable, max_length=0 + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME, 0) + + self.assertEqual(dat, b'') + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # 1 skippable, trail + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME + self.TRAIL) + + self.assertEqual(dat, b'') + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, self.TRAIL) + self.assertEqual(d.unused_data, self.TRAIL) # twice + + # incomplete + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME[:-1]) + + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # incomplete + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME[:-1], 0) + + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(b'') + + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + + +class ZstdDictTestCase(unittest.TestCase): + + def test_is_raw(self): + # content < 8 + b = b'1234567' + with self.assertRaises(ValueError): + ZstdDict(b) + + # content == 8 + b = b'12345678' + zd = ZstdDict(b, is_raw=True) + self.assertEqual(zd.dict_id, 0) + + temp = compress(b'aaa12345678', level=3, zstd_dict=zd) + self.assertEqual(b'aaa12345678', decompress(temp, zd)) + + # is_raw == False + b = b'12345678abcd' + with self.assertRaises(ValueError): + ZstdDict(b) + + # read only attributes + with self.assertRaises(AttributeError): + zd.dict_content = b + + with 
self.assertRaises(AttributeError): + zd.dict_id = 10000 + + # ZstdDict arguments + zd = ZstdDict(TRAINED_DICT.dict_content, is_raw=False) + self.assertNotEqual(zd.dict_id, 0) + + zd = ZstdDict(TRAINED_DICT.dict_content, is_raw=True) + self.assertNotEqual(zd.dict_id, 0) # note this assertion + + with self.assertRaises(TypeError): + ZstdDict("12345678abcdef", is_raw=True) + with self.assertRaises(TypeError): + ZstdDict(TRAINED_DICT) + + # invalid parameter + with self.assertRaises(TypeError): + ZstdDict(desk333=345) + + def test_invalid_dict(self): + DICT_MAGIC = 0xEC30A437.to_bytes(4, byteorder='little') + dict_content = DICT_MAGIC + b'abcdefghighlmnopqrstuvwxyz' + + # corrupted + zd = ZstdDict(dict_content, is_raw=False) + with self.assertRaisesRegex(ZstdError, r'ZSTD_CDict.*?corrupted'): + ZstdCompressor(zstd_dict=zd.as_digested_dict) + with self.assertRaisesRegex(ZstdError, r'ZSTD_DDict.*?corrupted'): + ZstdDecompressor(zd) + + # wrong type + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdCompressor(zstd_dict=(zd, b'123')) + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdCompressor(zstd_dict=(zd, 1, 2)) + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdCompressor(zstd_dict=(zd, -1)) + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdCompressor(zstd_dict=(zd, 3)) + + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdDecompressor(zstd_dict=(zd, b'123')) + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdDecompressor((zd, 1, 2)) + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdDecompressor((zd, -1)) + with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + ZstdDecompressor((zd, 3)) + + def test_train_dict(self): + + + TRAINED_DICT = train_dict(SAMPLES, DICT_SIZE1) + ZstdDict(TRAINED_DICT.dict_content, False) + + self.assertNotEqual(TRAINED_DICT.dict_id, 0) + self.assertGreater(len(TRAINED_DICT.dict_content), 0) + self.assertLessEqual(len(TRAINED_DICT.dict_content), DICT_SIZE1) + self.assertTrue(re.match(r'^$', str(TRAINED_DICT))) + + # compress/decompress + c = ZstdCompressor(zstd_dict=TRAINED_DICT) + for sample in SAMPLES: + dat1 = compress(sample, zstd_dict=TRAINED_DICT) + dat2 = decompress(dat1, TRAINED_DICT) + self.assertEqual(sample, dat2) + + dat1 = c.compress(sample) + dat1 += c.flush() + dat2 = decompress(dat1, TRAINED_DICT) + self.assertEqual(sample, dat2) + + def test_finalize_dict(self): + if zstd_version_info < (1, 4, 5): + return + + DICT_SIZE2 = 200*1024 + C_LEVEL = 6 + + try: + dic2 = finalize_dict(TRAINED_DICT, SAMPLES, DICT_SIZE2, C_LEVEL) + except NotImplementedError: + # < v1.4.5 at compile-time, >= v.1.4.5 at run-time + return + + self.assertNotEqual(dic2.dict_id, 0) + self.assertGreater(len(dic2.dict_content), 0) + self.assertLessEqual(len(dic2.dict_content), DICT_SIZE2) + + # compress/decompress + c = ZstdCompressor(C_LEVEL, zstd_dict=dic2) + for sample in SAMPLES: + dat1 = compress(sample, C_LEVEL, zstd_dict=dic2) + dat2 = decompress(dat1, dic2) + self.assertEqual(sample, dat2) + + dat1 = c.compress(sample) + dat1 += c.flush() + dat2 = decompress(dat1, dic2) + self.assertEqual(sample, dat2) + + # dict mismatch + self.assertNotEqual(TRAINED_DICT.dict_id, dic2.dict_id) + + dat1 = compress(SAMPLES[0], zstd_dict=TRAINED_DICT) + with self.assertRaises(ZstdError): + decompress(dat1, dic2) + + def test_train_dict_arguments(self): + with self.assertRaises(ValueError): 
+ train_dict([], 100*_1K) + + with self.assertRaises(ValueError): + train_dict(SAMPLES, -100) + + with self.assertRaises(ValueError): + train_dict(SAMPLES, 0) + + def test_finalize_dict_arguments(self): + if zstd_version_info < (1, 4, 5): + with self.assertRaises(NotImplementedError): + finalize_dict({1:2}, [b'aaa', b'bbb'], 100*_1K, 2) + return + + try: + finalize_dict(TRAINED_DICT, SAMPLES, 1*_1M, 2) + except NotImplementedError: + # < v1.4.5 at compile-time, >= v.1.4.5 at run-time + return + + with self.assertRaises(ValueError): + finalize_dict(TRAINED_DICT, [], 100*_1K, 2) + + with self.assertRaises(ValueError): + finalize_dict(TRAINED_DICT, SAMPLES, -100, 2) + + with self.assertRaises(ValueError): + finalize_dict(TRAINED_DICT, SAMPLES, 0, 2) + + def test_train_dict_c(self): + # argument wrong type + with self.assertRaises(TypeError): + _zstd._train_dict({}, [], 100) + with self.assertRaises(TypeError): + _zstd._train_dict(b'', 99, 100) + with self.assertRaises(TypeError): + _zstd._train_dict(b'', [], 100.1) + + # size > size_t + with self.assertRaises(ValueError): + _zstd._train_dict(b'', [2**64+1], 100) + + # dict_size <= 0 + with self.assertRaises(ValueError): + _zstd._train_dict(b'', [], 0) + + def test_finalize_dict_c(self): + if zstd_version_info < (1, 4, 5): + with self.assertRaises(NotImplementedError): + _zstd._finalize_dict(1, 2, 3, 4, 5) + return + + try: + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'123', [3,], 1*_1M, 5) + except NotImplementedError: + # < v1.4.5 at compile-time, >= v.1.4.5 at run-time + return + + # argument wrong type + with self.assertRaises(TypeError): + _zstd._finalize_dict({}, b'', [], 100, 5) + with self.assertRaises(TypeError): + _zstd._finalize_dict(TRAINED_DICT.dict_content, {}, [], 100, 5) + with self.assertRaises(TypeError): + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', 99, 100, 5) + with self.assertRaises(TypeError): + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', [], 100.1, 5) + with self.assertRaises(TypeError): + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', [], 100, 5.1) + + # size > size_t + with self.assertRaises(ValueError): + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', [2**64+1], 100, 5) + + # dict_size <= 0 + with self.assertRaises(ValueError): + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', [], 0, 5) + + def test_train_buffer_protocol_samples(self): + def _nbytes(dat): + if isinstance(dat, (bytes, bytearray)): + return len(dat) + return memoryview(dat).nbytes + + # prepare samples + chunk_lst = [] + wrong_size_lst = [] + correct_size_lst = [] + for _ in range(300): + arr = array.array('Q', [random.randint(0, 20) for i in range(20)]) + chunk_lst.append(arr) + correct_size_lst.append(_nbytes(arr)) + wrong_size_lst.append(len(arr)) + concatenation = b''.join(chunk_lst) + + # wrong size list + with self.assertRaisesRegex(ValueError, + "The samples size list doesn't match the concatenation's size"): + _zstd._train_dict(concatenation, wrong_size_lst, 100*1024) + + # correct size list + _zstd._train_dict(concatenation, correct_size_lst, 3*1024) + + # test _finalize_dict + if zstd_version_info < (1, 4, 5): + return + + # wrong size list + with self.assertRaisesRegex(ValueError, + "The samples size list doesn't match the concatenation's size"): + _zstd._finalize_dict(TRAINED_DICT.dict_content, + concatenation, wrong_size_lst, 300*1024, 5) + + # correct size list + _zstd._finalize_dict(TRAINED_DICT.dict_content, + concatenation, correct_size_lst, 300*1024, 5) + + def test_as_prefix(self): + # V1 + V1 
= THIS_FILE_BYTES + zd = ZstdDict(V1, True) + + # V2 + mid = len(V1) // 2 + V2 = V1[:mid] + \ + (b'a' if V1[mid] != b'a' else b'b') + \ + V1[mid+1:] + + # compress + dat = compress(V2, zstd_dict=zd.as_prefix) + self.assertEqual(get_frame_info(dat).dictionary_id, 0) + + # decompress + self.assertEqual(decompress(dat, zd.as_prefix), V2) + + # use wrong prefix + zd2 = ZstdDict(SAMPLES[0], True) + try: + decompressed = decompress(dat, zd2.as_prefix) + except ZstdError: # expected + pass + else: + self.assertNotEqual(decompressed, V2) + + # read only attribute + with self.assertRaises(AttributeError): + zd.as_prefix = b'1234' + + def test_as_digested_dict(self): + zd = TRAINED_DICT + + # test .as_digested_dict + dat = compress(SAMPLES[0], zstd_dict=zd.as_digested_dict) + self.assertEqual(decompress(dat, zd.as_digested_dict), SAMPLES[0]) + with self.assertRaises(AttributeError): + zd.as_digested_dict = b'1234' + + # test .as_undigested_dict + dat = compress(SAMPLES[0], zstd_dict=zd.as_undigested_dict) + self.assertEqual(decompress(dat, zd.as_undigested_dict), SAMPLES[0]) + with self.assertRaises(AttributeError): + zd.as_undigested_dict = b'1234' + + def test_advanced_compression_parameters(self): + options = {CParameter.compressionLevel: 6, + CParameter.windowLog: 20, + CParameter.enableLongDistanceMatching: 1} + + # automatically select + dat = compress(SAMPLES[0], options=options, zstd_dict=TRAINED_DICT) + self.assertEqual(decompress(dat, TRAINED_DICT), SAMPLES[0]) + + # explicitly select + dat = compress(SAMPLES[0], options=options, zstd_dict=TRAINED_DICT.as_digested_dict) + self.assertEqual(decompress(dat, TRAINED_DICT), SAMPLES[0]) + + def test_len(self): + self.assertEqual(len(TRAINED_DICT), len(TRAINED_DICT.dict_content)) + self.assertIn(str(len(TRAINED_DICT)), str(TRAINED_DICT)) + +class FileTestCase(unittest.TestCase): + def setUp(self): + self.DECOMPRESSED_42 = b'a'*42 + self.FRAME_42 = compress(self.DECOMPRESSED_42) + + def test_init(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + pass + with ZstdFile(io.BytesIO(), "w") as f: + pass + with ZstdFile(io.BytesIO(), "x") as f: + pass + with ZstdFile(io.BytesIO(), "a") as f: + pass + + with ZstdFile(io.BytesIO(), "w", level=12) as f: + pass + with ZstdFile(io.BytesIO(), "w", options={CParameter.checksumFlag:1}) as f: + pass + with ZstdFile(io.BytesIO(), "w", options={}) as f: + pass + with ZstdFile(io.BytesIO(), "w", level=20, zstd_dict=TRAINED_DICT) as f: + pass + + with ZstdFile(io.BytesIO(), "r", options={DParameter.windowLogMax:25}) as f: + pass + with ZstdFile(io.BytesIO(), "r", options={}, zstd_dict=TRAINED_DICT) as f: + pass + + def test_init_with_PathLike_filename(self): + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + filename = pathlib.Path(tmp_f.name) + + with ZstdFile(filename, "a") as f: + f.write(DECOMPRESSED_100_PLUS_32KB) + with ZstdFile(filename) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + + with ZstdFile(filename, "a") as f: + f.write(DECOMPRESSED_100_PLUS_32KB) + with ZstdFile(filename) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB * 2) + + os.remove(filename) + + def test_init_with_filename(self): + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + filename = pathlib.Path(tmp_f.name) + + with ZstdFile(filename) as f: + pass + with ZstdFile(filename, "w") as f: + pass + with ZstdFile(filename, "a") as f: + pass + + os.remove(filename) + + def test_init_mode(self): + bi = io.BytesIO() + + with ZstdFile(bi, "r"): + pass + with ZstdFile(bi, 
"rb"): + pass + with ZstdFile(bi, "w"): + pass + with ZstdFile(bi, "wb"): + pass + with ZstdFile(bi, "a"): + pass + with ZstdFile(bi, "ab"): + pass + + def test_init_with_x_mode(self): + with tempfile.NamedTemporaryFile() as tmp_f: + filename = pathlib.Path(tmp_f.name) + + for mode in ("x", "xb"): + with ZstdFile(filename, mode): + pass + with self.assertRaises(FileExistsError): + with ZstdFile(filename, mode): + pass + os.remove(filename) + + def test_init_bad_mode(self): + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), (3, "x")) + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "xt") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "x+") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rx") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "wx") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rt") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r+") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "wt") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "w+") + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rw") + + with self.assertRaisesRegex(TypeError, r"NOT be CParameter"): + ZstdFile(io.BytesIO(), 'rb', options={CParameter.compressionLevel:5}) + with self.assertRaisesRegex(TypeError, r"NOT be DParameter"): + ZstdFile(io.BytesIO(), 'wb', options={DParameter.windowLogMax:21}) + + with self.assertRaises(TypeError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", level=12) + + def test_init_bad_check(self): + with self.assertRaises(TypeError): + ZstdFile(io.BytesIO(), "w", level='asd') + # CHECK_UNKNOWN and anything above CHECK_ID_MAX should be invalid. + with self.assertRaises(ZstdError): + ZstdFile(io.BytesIO(), "w", options={999:9999}) + with self.assertRaises(ZstdError): + ZstdFile(io.BytesIO(), "w", options={CParameter.windowLog:99}) + + with self.assertRaises(TypeError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", options=33) + + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), + options={DParameter.windowLogMax:2**31}) + + with self.assertRaises(ZstdError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), + options={444:333}) + + with self.assertRaises(TypeError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), zstd_dict={1:2}) + + with self.assertRaises(TypeError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), zstd_dict=b'dict123456') + + def test_init_close_fp(self): + # get a temp file name + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + tmp_f.write(DAT_130K_C) + filename = tmp_f.name + + with self.assertRaises(ValueError): + ZstdFile(filename, options={'a':'b'}) + + # for PyPy + gc.collect() + + os.remove(filename) + + def test_close(self): + with io.BytesIO(COMPRESSED_100_PLUS_32KB) as src: + f = ZstdFile(src) + f.close() + # ZstdFile.close() should not close the underlying file object. + self.assertFalse(src.closed) + # Try closing an already-closed ZstdFile. + f.close() + self.assertFalse(src.closed) + + # Test with a real file on disk, opened directly by ZstdFile. 
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + filename = pathlib.Path(tmp_f.name) + + f = ZstdFile(filename) + fp = f._fp + f.close() + # Here, ZstdFile.close() *should* close the underlying file object. + self.assertTrue(fp.closed) + # Try closing an already-closed ZstdFile. + f.close() + + os.remove(filename) + + def test_closed(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertFalse(f.closed) + f.read() + self.assertFalse(f.closed) + finally: + f.close() + self.assertTrue(f.closed) + + f = ZstdFile(io.BytesIO(), "w") + try: + self.assertFalse(f.closed) + finally: + f.close() + self.assertTrue(f.closed) + + def test_fileno(self): + # 1 + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertRaises(io.UnsupportedOperation, f.fileno) + finally: + f.close() + self.assertRaises(ValueError, f.fileno) + + # 2 + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + filename = pathlib.Path(tmp_f.name) + + f = ZstdFile(filename) + try: + self.assertEqual(f.fileno(), f._fp.fileno()) + self.assertIsInstance(f.fileno(), int) + finally: + f.close() + self.assertRaises(ValueError, f.fileno) + + os.remove(filename) + + # 3, no .fileno() method + class C: + def read(self, size=-1): + return b'123' + with ZstdFile(C(), 'rb') as f: + with self.assertRaisesRegex(AttributeError, r'fileno'): + f.fileno() + + def test_seekable(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertTrue(f.seekable()) + f.read() + self.assertTrue(f.seekable()) + finally: + f.close() + self.assertRaises(ValueError, f.seekable) + + f = ZstdFile(io.BytesIO(), "w") + try: + self.assertFalse(f.seekable()) + finally: + f.close() + self.assertRaises(ValueError, f.seekable) + + src = io.BytesIO(COMPRESSED_100_PLUS_32KB) + src.seekable = lambda: False + f = ZstdFile(src) + try: + self.assertFalse(f.seekable()) + finally: + f.close() + self.assertRaises(ValueError, f.seekable) + + def test_readable(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertTrue(f.readable()) + f.read() + self.assertTrue(f.readable()) + finally: + f.close() + self.assertRaises(ValueError, f.readable) + + f = ZstdFile(io.BytesIO(), "w") + try: + self.assertFalse(f.readable()) + finally: + f.close() + self.assertRaises(ValueError, f.readable) + + def test_writable(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertFalse(f.writable()) + f.read() + self.assertFalse(f.writable()) + finally: + f.close() + self.assertRaises(ValueError, f.writable) + + f = ZstdFile(io.BytesIO(), "w") + try: + self.assertTrue(f.writable()) + finally: + f.close() + self.assertRaises(ValueError, f.writable) + + def test_read_0(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + self.assertEqual(f.read(0), b"") + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), + options={DParameter.windowLogMax:20}) as f: + self.assertEqual(f.read(0), b"") + + # empty file + with ZstdFile(io.BytesIO(b'')) as f: + self.assertEqual(f.read(0), b"") + with self.assertRaises(EOFError): + f.read(10) + + with ZstdFile(io.BytesIO(b'')) as f: + with self.assertRaises(EOFError): + f.read(10) + + def test_read_10(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + chunks = [] + while True: + result = f.read(10) + if not result: + break + self.assertLessEqual(len(result), 10) + chunks.append(result) + self.assertEqual(b"".join(chunks), DECOMPRESSED_100_PLUS_32KB) + + def 
test_read_multistream(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 5)) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB * 5) + + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB + SKIPPABLE_FRAME)) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB + COMPRESSED_DAT)) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB + DECOMPRESSED_DAT) + + def test_read_incomplete(self): + with ZstdFile(io.BytesIO(DAT_130K_C[:-200])) as f: + self.assertRaises(EOFError, f.read) + + # Trailing data isn't a valid compressed stream + with ZstdFile(io.BytesIO(self.FRAME_42 + b'12345')) as f: + self.assertEqual(f.read(), self.DECOMPRESSED_42) + + with ZstdFile(io.BytesIO(SKIPPABLE_FRAME + b'12345')) as f: + self.assertEqual(f.read(), b'') + + def test_read_truncated(self): + # Drop stream epilogue: 4 bytes checksum + truncated = DAT_130K_C[:-4] + with ZstdFile(io.BytesIO(truncated)) as f: + self.assertRaises(EOFError, f.read) + + with ZstdFile(io.BytesIO(truncated)) as f: + # this is an important test, make sure it doesn't raise EOFError. + self.assertEqual(f.read(130*1024), DAT_130K_D) + with self.assertRaises(EOFError): + f.read(1) + + # Incomplete header + for i in range(1, 20): + with ZstdFile(io.BytesIO(truncated[:i])) as f: + self.assertRaises(EOFError, f.read, 1) + + def test_read_bad_args(self): + f = ZstdFile(io.BytesIO(COMPRESSED_DAT)) + f.close() + self.assertRaises(ValueError, f.read) + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(ValueError, f.read) + with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f: + self.assertRaises(TypeError, f.read, float()) + + def test_read_bad_data(self): + with ZstdFile(io.BytesIO(COMPRESSED_BOGUS)) as f: + self.assertRaises(ZstdError, f.read) + + def test_read_exception(self): + class C: + def read(self, size=-1): + raise OSError + with ZstdFile(C()) as f: + with self.assertRaises(OSError): + f.read(10) + + def test_read1(self): + with ZstdFile(io.BytesIO(DAT_130K_C)) as f: + blocks = [] + while True: + result = f.read1() + if not result: + break + blocks.append(result) + self.assertEqual(b"".join(blocks), DAT_130K_D) + self.assertEqual(f.read1(), b"") + + def test_read1_0(self): + with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f: + self.assertEqual(f.read1(0), b"") + + def test_read1_10(self): + with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f: + blocks = [] + while True: + result = f.read1(10) + if not result: + break + blocks.append(result) + self.assertEqual(b"".join(blocks), DECOMPRESSED_DAT) + self.assertEqual(f.read1(), b"") + + def test_read1_multistream(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 5)) as f: + blocks = [] + while True: + result = f.read1() + if not result: + break + blocks.append(result) + self.assertEqual(b"".join(blocks), DECOMPRESSED_100_PLUS_32KB * 5) + self.assertEqual(f.read1(), b"") + + def test_read1_bad_args(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + f.close() + self.assertRaises(ValueError, f.read1) + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(ValueError, f.read1) + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + self.assertRaises(TypeError, f.read1, None) + + def test_readinto(self): + arr = array.array("I", range(100)) + self.assertEqual(len(arr), 100) + self.assertEqual(len(arr) * arr.itemsize, 400) + ba = bytearray(300) + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + # 0 length output buffer + self.assertEqual(f.readinto(ba[0:0]), 0) + + # use 
correct length for buffer protocol object + self.assertEqual(f.readinto(arr), 400) + self.assertEqual(arr.tobytes(), DECOMPRESSED_100_PLUS_32KB[:400]) + + # normal readinto + self.assertEqual(f.readinto(ba), 300) + self.assertEqual(ba, DECOMPRESSED_100_PLUS_32KB[400:700]) + + def test_peek(self): + with ZstdFile(io.BytesIO(DAT_130K_C)) as f: + result = f.peek() + self.assertGreater(len(result), 0) + self.assertTrue(DAT_130K_D.startswith(result)) + self.assertEqual(f.read(), DAT_130K_D) + with ZstdFile(io.BytesIO(DAT_130K_C)) as f: + result = f.peek(10) + self.assertGreater(len(result), 0) + self.assertTrue(DAT_130K_D.startswith(result)) + self.assertEqual(f.read(), DAT_130K_D) + + def test_peek_bad_args(self): + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(ValueError, f.peek) + + def test_iterator(self): + with io.BytesIO(THIS_FILE_BYTES) as f: + lines = f.readlines() + compressed = compress(THIS_FILE_BYTES) + + # iter + with ZstdFile(io.BytesIO(compressed)) as f: + self.assertListEqual(list(iter(f)), lines) + + # readline + with ZstdFile(io.BytesIO(compressed)) as f: + for line in lines: + self.assertEqual(f.readline(), line) + self.assertEqual(f.readline(), b'') + self.assertEqual(f.readline(), b'') + + # readlines + with ZstdFile(io.BytesIO(compressed)) as f: + self.assertListEqual(f.readlines(), lines) + + def test_decompress_limited(self): + _ZSTD_DStreamInSize = 128*1024 + 3 + + bomb = compress(b'\0' * int(2e6), level=10) + self.assertLess(len(bomb), _ZSTD_DStreamInSize) + + decomp = ZstdFile(io.BytesIO(bomb)) + self.assertEqual(decomp.read(1), b'\0') + + # BufferedReader uses 128 KiB buffer in __init__.py + max_decomp = 128*1024 + self.assertLessEqual(decomp._buffer.raw.tell(), max_decomp, + "Excessive amount of data was decompressed") + + def test_write(self): + raw_data = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6] + with io.BytesIO() as dst: + with ZstdFile(dst, "w") as f: + f.write(raw_data) + + comp = ZstdCompressor() + expected = comp.compress(raw_data) + comp.flush() + self.assertEqual(dst.getvalue(), expected) + + with io.BytesIO() as dst: + with ZstdFile(dst, "w", level=12) as f: + f.write(raw_data) + + comp = ZstdCompressor(12) + expected = comp.compress(raw_data) + comp.flush() + self.assertEqual(dst.getvalue(), expected) + + with io.BytesIO() as dst: + with ZstdFile(dst, "w", options={CParameter.checksumFlag:1}) as f: + f.write(raw_data) + + comp = ZstdCompressor({CParameter.checksumFlag:1}) + expected = comp.compress(raw_data) + comp.flush() + self.assertEqual(dst.getvalue(), expected) + + with io.BytesIO() as dst: + options = {CParameter.compressionLevel:-5, + CParameter.checksumFlag:1} + with ZstdFile(dst, "w", + options=options) as f: + f.write(raw_data) + + comp = ZstdCompressor(options=options) + expected = comp.compress(raw_data) + comp.flush() + self.assertEqual(dst.getvalue(), expected) + + def test_write_empty_frame(self): + # .FLUSH_FRAME generates an empty content frame + c = ZstdCompressor() + self.assertNotEqual(c.flush(c.FLUSH_FRAME), b'') + self.assertNotEqual(c.flush(c.FLUSH_FRAME), b'') + + # don't generate empty content frame + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + pass + self.assertEqual(bo.getvalue(), b'') + + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.flush(f.FLUSH_FRAME) + self.assertEqual(bo.getvalue(), b'') + + # if .write(b''), generate empty content frame + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.write(b'') + self.assertNotEqual(bo.getvalue(), b'') + + # has an empty content frame + bo = 
io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.flush(f.FLUSH_BLOCK) + self.assertNotEqual(bo.getvalue(), b'') + + def test_write_empty_block(self): + # If no internal data, .FLUSH_BLOCK return b''. + c = ZstdCompressor() + self.assertEqual(c.flush(c.FLUSH_BLOCK), b'') + self.assertNotEqual(c.compress(b'123', c.FLUSH_BLOCK), + b'') + self.assertEqual(c.flush(c.FLUSH_BLOCK), b'') + self.assertEqual(c.compress(b''), b'') + self.assertEqual(c.compress(b''), b'') + self.assertEqual(c.flush(c.FLUSH_BLOCK), b'') + + # mode = .last_mode + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.write(b'123') + f.flush(f.FLUSH_BLOCK) + fp_pos = f._fp.tell() + self.assertNotEqual(fp_pos, 0) + f.flush(f.FLUSH_BLOCK) + self.assertEqual(f._fp.tell(), fp_pos) + + # mode != .last_mode + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.flush(f.FLUSH_BLOCK) + self.assertEqual(f._fp.tell(), 0) + f.write(b'') + f.flush(f.FLUSH_BLOCK) + self.assertEqual(f._fp.tell(), 0) + + def test_write_101(self): + with io.BytesIO() as dst: + with ZstdFile(dst, "w") as f: + for start in range(0, len(THIS_FILE_BYTES), 101): + f.write(THIS_FILE_BYTES[start:start+101]) + + comp = ZstdCompressor() + expected = comp.compress(THIS_FILE_BYTES) + comp.flush() + self.assertEqual(dst.getvalue(), expected) + + def test_write_append(self): + def comp(data): + comp = ZstdCompressor() + return comp.compress(data) + comp.flush() + + part1 = THIS_FILE_BYTES[:1024] + part2 = THIS_FILE_BYTES[1024:1536] + part3 = THIS_FILE_BYTES[1536:] + expected = b"".join(comp(x) for x in (part1, part2, part3)) + with io.BytesIO() as dst: + with ZstdFile(dst, "w") as f: + f.write(part1) + with ZstdFile(dst, "a") as f: + f.write(part2) + with ZstdFile(dst, "a") as f: + f.write(part3) + self.assertEqual(dst.getvalue(), expected) + + def test_write_bad_args(self): + f = ZstdFile(io.BytesIO(), "w") + f.close() + self.assertRaises(ValueError, f.write, b"foo") + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r") as f: + self.assertRaises(ValueError, f.write, b"bar") + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(TypeError, f.write, None) + self.assertRaises(TypeError, f.write, "text") + self.assertRaises(TypeError, f.write, 789) + + def test_writelines(self): + def comp(data): + comp = ZstdCompressor() + return comp.compress(data) + comp.flush() + + with io.BytesIO(THIS_FILE_BYTES) as f: + lines = f.readlines() + with io.BytesIO() as dst: + with ZstdFile(dst, "w") as f: + f.writelines(lines) + expected = comp(THIS_FILE_BYTES) + self.assertEqual(dst.getvalue(), expected) + + def test_seek_forward(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(555) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[555:]) + + def test_seek_forward_across_streams(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 2)) as f: + f.seek(len(DECOMPRESSED_100_PLUS_32KB) + 123) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[123:]) + + def test_seek_forward_relative_to_current(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.read(100) + f.seek(1236, 1) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[1336:]) + + def test_seek_forward_relative_to_end(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(-555, 2) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[-555:]) + + def test_seek_backward(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.read(1001) + f.seek(211) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[211:]) + + def 
test_seek_backward_across_streams(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 2)) as f: + f.read(len(DECOMPRESSED_100_PLUS_32KB) + 333) + f.seek(737) + self.assertEqual(f.read(), + DECOMPRESSED_100_PLUS_32KB[737:] + DECOMPRESSED_100_PLUS_32KB) + + def test_seek_backward_relative_to_end(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(-150, 2) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[-150:]) + + def test_seek_past_end(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(len(DECOMPRESSED_100_PLUS_32KB) + 9001) + self.assertEqual(f.tell(), len(DECOMPRESSED_100_PLUS_32KB)) + self.assertEqual(f.read(), b"") + + def test_seek_past_start(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(-88) + self.assertEqual(f.tell(), 0) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + + def test_seek_bad_args(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + f.close() + self.assertRaises(ValueError, f.seek, 0) + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(ValueError, f.seek, 0) + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + self.assertRaises(ValueError, f.seek, 0, 3) + # io.BufferedReader raises TypeError instead of ValueError + self.assertRaises((TypeError, ValueError), f.seek, 9, ()) + self.assertRaises(TypeError, f.seek, None) + self.assertRaises(TypeError, f.seek, b"derp") + + def test_seek_not_seekable(self): + class C(io.BytesIO): + def seekable(self): + return False + obj = C(COMPRESSED_100_PLUS_32KB) + with ZstdFile(obj, 'r') as f: + d = f.read(1) + self.assertFalse(f.seekable()) + with self.assertRaisesRegex(io.UnsupportedOperation, + 'File or stream is not seekable'): + f.seek(0) + d += f.read() + self.assertEqual(d, DECOMPRESSED_100_PLUS_32KB) + + def test_tell(self): + with ZstdFile(io.BytesIO(DAT_130K_C)) as f: + pos = 0 + while True: + self.assertEqual(f.tell(), pos) + result = f.read(random.randint(171, 189)) + if not result: + break + pos += len(result) + self.assertEqual(f.tell(), len(DAT_130K_D)) + with ZstdFile(io.BytesIO(), "w") as f: + for pos in range(0, len(DAT_130K_D), 143): + self.assertEqual(f.tell(), pos) + f.write(DAT_130K_D[pos:pos+143]) + self.assertEqual(f.tell(), len(DAT_130K_D)) + + def test_tell_bad_args(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + f.close() + self.assertRaises(ValueError, f.tell) + + def test_file_dict(self): + # default + bi = io.BytesIO() + with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with ZstdFile(bi, zstd_dict=TRAINED_DICT) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + # .as_(un)digested_dict + bi = io.BytesIO() + with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT.as_digested_dict) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with ZstdFile(bi, zstd_dict=TRAINED_DICT.as_undigested_dict) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + def test_file_prefix(self): + bi = io.BytesIO() + with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT.as_prefix) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with ZstdFile(bi, zstd_dict=TRAINED_DICT.as_prefix) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + def test_UnsupportedOperation(self): + # 1 + with ZstdFile(io.BytesIO(), 'r') as f: + with self.assertRaises(io.UnsupportedOperation): + f.write(b'1234') + + # 2 + class T: + def read(self, size): + return b'a' * size + + with self.assertRaises(AttributeError): # on close + with ZstdFile(T(), 'w') as f: + with 
self.assertRaises(AttributeError): # on write + f.write(b'1234') + + # 3 + with ZstdFile(io.BytesIO(), 'w') as f: + with self.assertRaises(io.UnsupportedOperation): + f.read(100) + with self.assertRaises(io.UnsupportedOperation): + f.seek(100) + self.assertEqual(f.closed, True) + with self.assertRaises(ValueError): + f.readable() + with self.assertRaises(ValueError): + f.tell() + with self.assertRaises(ValueError): + f.read(100) + + def test_read_readinto_readinto1(self): + lst = [] + with ZstdFile(io.BytesIO(COMPRESSED_THIS_FILE*5)) as f: + while True: + method = random.randint(0, 2) + size = random.randint(0, 300) + + if method == 0: + dat = f.read(size) + if not dat and size: + break + lst.append(dat) + elif method == 1: + ba = bytearray(size) + read_size = f.readinto(ba) + if read_size == 0 and size: + break + lst.append(bytes(ba[:read_size])) + elif method == 2: + ba = bytearray(size) + read_size = f.readinto1(ba) + if read_size == 0 and size: + break + lst.append(bytes(ba[:read_size])) + self.assertEqual(b''.join(lst), THIS_FILE_BYTES*5) + + def test_zstdfile_flush(self): + # closed + f = ZstdFile(io.BytesIO(), 'w') + f.close() + with self.assertRaises(ValueError): + f.flush() + + # read + with ZstdFile(io.BytesIO(), 'r') as f: + # does nothing for read-only stream + f.flush() + + # write + DAT = b'abcd' + bi = io.BytesIO() + with ZstdFile(bi, 'w') as f: + self.assertEqual(f.write(DAT), len(DAT)) + self.assertEqual(f.tell(), len(DAT)) + self.assertEqual(bi.tell(), 0) # not enough for a block + + self.assertEqual(f.flush(), None) + self.assertEqual(f.tell(), len(DAT)) + self.assertGreater(bi.tell(), 0) # flushed + + # write, no .flush() method + class C: + def write(self, b): + return len(b) + with ZstdFile(C(), 'w') as f: + self.assertEqual(f.write(DAT), len(DAT)) + self.assertEqual(f.tell(), len(DAT)) + + self.assertEqual(f.flush(), None) + self.assertEqual(f.tell(), len(DAT)) + + def test_zstdfile_flush_mode(self): + self.assertEqual(ZstdFile.FLUSH_BLOCK, ZstdCompressor.FLUSH_BLOCK) + self.assertEqual(ZstdFile.FLUSH_FRAME, ZstdCompressor.FLUSH_FRAME) + with self.assertRaises(AttributeError): + ZstdFile.CONTINUE + + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + # flush block + self.assertEqual(f.write(b'123'), 3) + self.assertIsNone(f.flush(f.FLUSH_BLOCK)) + p1 = bo.tell() + # mode == .last_mode, should return + self.assertIsNone(f.flush()) + p2 = bo.tell() + self.assertEqual(p1, p2) + # flush frame + self.assertEqual(f.write(b'456'), 3) + self.assertIsNone(f.flush(mode=f.FLUSH_FRAME)) + # flush frame + self.assertEqual(f.write(b'789'), 3) + self.assertIsNone(f.flush(f.FLUSH_FRAME)) + p1 = bo.tell() + # mode == .last_mode, should return + self.assertIsNone(f.flush(f.FLUSH_FRAME)) + p2 = bo.tell() + self.assertEqual(p1, p2) + self.assertEqual(decompress(bo.getvalue()), b'123456789') + + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.write(b'123') + with self.assertRaisesRegex(ValueError, r'\.FLUSH_.*?\.FLUSH_'): + f.flush(ZstdCompressor.CONTINUE) + with self.assertRaises(ValueError): + f.flush(-1) + with self.assertRaises(ValueError): + f.flush(123456) + with self.assertRaises(TypeError): + f.flush(node=ZstdCompressor.CONTINUE) + with self.assertRaises((TypeError, ValueError)): + f.flush('FLUSH_FRAME') + with self.assertRaises(TypeError): + f.flush(b'456', f.FLUSH_BLOCK) + + def test_zstdfile_truncate(self): + with ZstdFile(io.BytesIO(), 'w') as f: + with self.assertRaises(io.UnsupportedOperation): + f.truncate(200) + + def test_zstdfile_iter_issue45475(self): + lines = [l 
for l in ZstdFile(io.BytesIO(COMPRESSED_THIS_FILE))] + self.assertGreater(len(lines), 0) + + def test_append_new_file(self): + with tempfile.NamedTemporaryFile(delete=True) as tmp_f: + filename = tmp_f.name + + with ZstdFile(filename, 'a') as f: + pass + self.assertTrue(os.path.isfile(filename)) + + os.remove(filename) + +class OpenTestCase(unittest.TestCase): + + def test_binary_modes(self): + with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb") as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + with io.BytesIO() as bio: + with open(bio, "wb") as f: + f.write(DECOMPRESSED_100_PLUS_32KB) + file_data = decompress(bio.getvalue()) + self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB) + with open(bio, "ab") as f: + f.write(DECOMPRESSED_100_PLUS_32KB) + file_data = decompress(bio.getvalue()) + self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB * 2) + + def test_text_modes(self): + # empty input + with self.assertRaises(EOFError): + with open(io.BytesIO(b''), "rt", encoding="utf-8", newline='\n') as reader: + for _ in reader: + pass + + # read + uncompressed = THIS_FILE_STR.replace(os.linesep, "\n") + with open(io.BytesIO(COMPRESSED_THIS_FILE), "rt", encoding="utf-8") as f: + self.assertEqual(f.read(), uncompressed) + + with io.BytesIO() as bio: + # write + with open(bio, "wt", encoding="utf-8") as f: + f.write(uncompressed) + file_data = decompress(bio.getvalue()).decode("utf-8") + self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed) + # append + with open(bio, "at", encoding="utf-8") as f: + f.write(uncompressed) + file_data = decompress(bio.getvalue()).decode("utf-8") + self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed * 2) + + def test_bad_params(self): + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + TESTFN = pathlib.Path(tmp_f.name) + + with self.assertRaises(ValueError): + open(TESTFN, "") + with self.assertRaises(ValueError): + open(TESTFN, "rbt") + with self.assertRaises(ValueError): + open(TESTFN, "rb", encoding="utf-8") + with self.assertRaises(ValueError): + open(TESTFN, "rb", errors="ignore") + with self.assertRaises(ValueError): + open(TESTFN, "rb", newline="\n") + + os.remove(TESTFN) + + def test_option(self): + options = {DParameter.windowLogMax:25} + with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb", options=options) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + + options = {CParameter.compressionLevel:12} + with io.BytesIO() as bio: + with open(bio, "wb", options=options) as f: + f.write(DECOMPRESSED_100_PLUS_32KB) + file_data = decompress(bio.getvalue()) + self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB) + + def test_encoding(self): + uncompressed = THIS_FILE_STR.replace(os.linesep, "\n") + + with io.BytesIO() as bio: + with open(bio, "wt", encoding="utf-16-le") as f: + f.write(uncompressed) + file_data = decompress(bio.getvalue()).decode("utf-16-le") + self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed) + bio.seek(0) + with open(bio, "rt", encoding="utf-16-le") as f: + self.assertEqual(f.read().replace(os.linesep, "\n"), uncompressed) + + def test_encoding_error_handler(self): + with io.BytesIO(compress(b"foo\xffbar")) as bio: + with open(bio, "rt", encoding="ascii", errors="ignore") as f: + self.assertEqual(f.read(), "foobar") + + def test_newline(self): + # Test with explicit newline (universal newline mode disabled). 
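+        # written with newline='\n' and read back with newline='\r', so readlines() returns the text unsplit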
+ text = THIS_FILE_STR.replace(os.linesep, "\n") + with io.BytesIO() as bio: + with open(bio, "wt", encoding="utf-8", newline="\n") as f: + f.write(text) + bio.seek(0) + with open(bio, "rt", encoding="utf-8", newline="\r") as f: + self.assertEqual(f.readlines(), [text]) + + def test_x_mode(self): + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + TESTFN = pathlib.Path(tmp_f.name) + + for mode in ("x", "xb", "xt"): + os.remove(TESTFN) + + if mode == "xt": + encoding = "utf-8" + else: + encoding = None + with open(TESTFN, mode, encoding=encoding): + pass + with self.assertRaises(FileExistsError): + with open(TESTFN, mode): + pass + + os.remove(TESTFN) + + def test_open_dict(self): + # default + bi = io.BytesIO() + with open(bi, 'w', zstd_dict=TRAINED_DICT) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with open(bi, zstd_dict=TRAINED_DICT) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + # .as_(un)digested_dict + bi = io.BytesIO() + with open(bi, 'w', zstd_dict=TRAINED_DICT.as_digested_dict) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with open(bi, zstd_dict=TRAINED_DICT.as_undigested_dict) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + # invalid dictionary + bi = io.BytesIO() + with self.assertRaisesRegex(TypeError, 'zstd_dict'): + open(bi, 'w', zstd_dict={1:2, 2:3}) + + with self.assertRaisesRegex(TypeError, 'zstd_dict'): + open(bi, 'w', zstd_dict=b'1234567890') + + def test_open_prefix(self): + bi = io.BytesIO() + with open(bi, 'w', zstd_dict=TRAINED_DICT.as_prefix) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with open(bi, zstd_dict=TRAINED_DICT.as_prefix) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + def test_buffer_protocol(self): + # don't use len() for buffer protocol objects + arr = array.array("i", range(1000)) + LENGTH = len(arr) * arr.itemsize + + with open(io.BytesIO(), "wb") as f: + self.assertEqual(f.write(arr), LENGTH) + self.assertEqual(f.tell(), LENGTH) + +class FreeThreadingMethodTests(unittest.TestCase): + + @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled') + @threading_helper.reap_threads + @threading_helper.requires_working_threading() + def test_compress_locking(self): + input = b'a'* (16*_1K) + num_threads = 8 + + comp = ZstdCompressor() + parts = [] + for _ in range(num_threads): + res = comp.compress(input, ZstdCompressor.FLUSH_BLOCK) + if res: + parts.append(res) + rest1 = comp.flush() + expected = b''.join(parts) + rest1 + + comp = ZstdCompressor() + output = [] + def run_method(method, input_data, output_data): + res = method(input_data, ZstdCompressor.FLUSH_BLOCK) + if res: + output_data.append(res) + threads = [] + + for i in range(num_threads): + thread = threading.Thread(target=run_method, args=(comp.compress, input, output)) + + threads.append(thread) + + with threading_helper.start_threads(threads): + pass + + rest2 = comp.flush() + self.assertEqual(rest1, rest2) + actual = b''.join(output) + rest2 + self.assertEqual(expected, actual) + + @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled') + @threading_helper.reap_threads + @threading_helper.requires_working_threading() + def test_decompress_locking(self): + input = compress(b'a'* (16*_1K)) + num_threads = 8 + # to ensure we decompress over multiple calls, set maxsize + window_size = _1K * 16//num_threads + + decomp = ZstdDecompressor() + parts = [] + for _ in range(num_threads): + res = decomp.decompress(input, window_size) + if res: + parts.append(res) + expected = 
b''.join(parts) + + comp = ZstdDecompressor() + output = [] + def run_method(method, input_data, output_data): + res = method(input_data, window_size) + if res: + output_data.append(res) + threads = [] + + for i in range(num_threads): + thread = threading.Thread(target=run_method, args=(comp.decompress, input, output)) + + threads.append(thread) + + with threading_helper.start_threads(threads): + pass + + actual = b''.join(output) + self.assertEqual(expected, actual) + + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/zipfile/__init__.py b/Lib/zipfile/__init__.py index cfb44f3ed970ee..88356abe8cbaeb 100644 --- a/Lib/zipfile/__init__.py +++ b/Lib/zipfile/__init__.py @@ -31,6 +31,11 @@ except ImportError: lzma = None +try: + from compression import zstd # We may need its compression method +except ImportError: + zstd = None + __all__ = ["BadZipFile", "BadZipfile", "error", "ZIP_STORED", "ZIP_DEFLATED", "ZIP_BZIP2", "ZIP_LZMA", "is_zipfile", "ZipInfo", "ZipFile", "PyZipFile", "LargeZipFile", @@ -58,12 +63,14 @@ class LargeZipFile(Exception): ZIP_DEFLATED = 8 ZIP_BZIP2 = 12 ZIP_LZMA = 14 +ZIP_ZSTANDARD = 93 # Other ZIP compression methods not supported DEFAULT_VERSION = 20 ZIP64_VERSION = 45 BZIP2_VERSION = 46 LZMA_VERSION = 63 +ZSTANDARD_VERSION = 63 # we recognize (but not necessarily support) all features up to that version MAX_EXTRACT_VERSION = 63 @@ -505,6 +512,8 @@ def FileHeader(self, zip64=None): min_version = max(BZIP2_VERSION, min_version) elif self.compress_type == ZIP_LZMA: min_version = max(LZMA_VERSION, min_version) + elif self.compress_type == ZIP_ZSTANDARD: + min_version = max(ZSTANDARD_VERSION, min_version) self.extract_version = max(min_version, self.extract_version) self.create_version = max(min_version, self.create_version) @@ -766,6 +775,7 @@ def decompress(self, data): 14: 'lzma', 18: 'terse', 19: 'lz77', + 93: 'zstd', 97: 'wavpack', 98: 'ppmd', } @@ -785,6 +795,10 @@ def _check_compression(compression): if not lzma: raise RuntimeError( "Compression requires the (missing) lzma module") + elif compression == ZIP_ZSTANDARD: + if not zstd: + raise RuntimeError( + "Compression requires the (missing) compression.zstd module") else: raise NotImplementedError("That compression method is not supported") @@ -798,9 +812,11 @@ def _get_compressor(compress_type, compresslevel=None): if compresslevel is not None: return bz2.BZ2Compressor(compresslevel) return bz2.BZ2Compressor() - # compresslevel is ignored for ZIP_LZMA + # compresslevel is ignored for ZIP_LZMA and ZIP_ZSTANDARD elif compress_type == ZIP_LZMA: return LZMACompressor() + elif compress_type == ZIP_ZSTANDARD: + return zstd.ZstdCompressor() else: return None @@ -815,6 +831,8 @@ def _get_decompressor(compress_type): return bz2.BZ2Decompressor() elif compress_type == ZIP_LZMA: return LZMADecompressor() + elif compress_type == ZIP_ZSTANDARD: + return zstd.ZstdDecompressor() else: descr = compressor_names.get(compress_type) if descr: diff --git a/Makefile.pre.in b/Makefile.pre.in index 3fda59cdcec71b..381b5f46312da7 100644 --- a/Makefile.pre.in +++ b/Makefile.pre.in @@ -2507,7 +2507,7 @@ maninstall: altmaninstall XMLLIBSUBDIRS= xml xml/dom xml/etree xml/parsers xml/sax LIBSUBDIRS= asyncio \ collections \ - compression compression/bz2 compression/gzip \ + compression compression/bz2 compression/gzip compression/zstd \ compression/lzma compression/zlib compression/_common \ concurrent concurrent/futures \ csv \ @@ -2677,6 +2677,7 @@ TESTSUBDIRS= idlelib/idle_test \ test/test_zipfile/_path \ test/test_zoneinfo \ 
test/test_zoneinfo/data \ + test/test_zstd \ test/tkinterdata \ test/tokenizedata \ test/tracedmodules \ From 4168895822fac42f37f5e9d496e79a29da08fbe2 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Tue, 29 Apr 2025 19:49:48 -0700 Subject: [PATCH 02/55] Fix byteswarning in test --- Lib/test/test_zstd/test_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd/test_core.py index 07463c519eb83a..f7de5decf59362 100644 --- a/Lib/test/test_zstd/test_core.py +++ b/Lib/test/test_zstd/test_core.py @@ -1477,7 +1477,7 @@ def test_as_prefix(self): # V2 mid = len(V1) // 2 V2 = V1[:mid] + \ - (b'a' if V1[mid] != b'a' else b'b') + \ + (b'a' if V1[mid] != int.from_bytes(b'a') else b'b') + \ V1[mid+1:] # compress From a22fa9b931dcbb70785d56c5722e9fd8a127064d Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sat, 3 May 2025 13:28:19 -0700 Subject: [PATCH 03/55] Remove shape tests --- Lib/test/test_zstd/test_core.py | 188 -------------------------------- 1 file changed, 188 deletions(-) diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd/test_core.py index f7de5decf59362..0721beb8060a91 100644 --- a/Lib/test/test_zstd/test_core.py +++ b/Lib/test/test_zstd/test_core.py @@ -185,194 +185,6 @@ def test_decompress_2x130_1K(self): dat = decompress(DAT_130K_C + DAT_130K_C) self.assertEqual(len(dat), 2 * _130_1K) -class ClassShapeTestCase(unittest.TestCase): - - def test_ZstdCompressor(self): - # class attributes - ZstdCompressor.CONTINUE - ZstdCompressor.FLUSH_BLOCK - ZstdCompressor.FLUSH_FRAME - - # method & me_1Mer - ZstdCompressor() - ZstdCompressor(12, zstd_dict=TRAINED_DICT) - c = ZstdCompressor(level=2, zstd_dict=TRAINED_DICT) - - c.compress(b"123456") - c.compress(b"123456", ZstdCompressor.CONTINUE) - c.compress(data=b"123456", mode=c.CONTINUE) - - c.flush() - c.flush(ZstdCompressor.FLUSH_BLOCK) - c.flush(mode=c.FLUSH_FRAME) - - c.last_mode - - # decompressor method & me_1Mer - with self.assertRaises(AttributeError): - c.decompress(b"") - with self.assertRaises(AttributeError): - c.at_frame_edge - with self.assertRaises(AttributeError): - c.eof - with self.assertRaises(AttributeError): - c.needs_input - - # read only attribute - with self.assertRaises(AttributeError): - c.last_mode = ZstdCompressor.FLUSH_BLOCK - - # name - self.assertIn(".ZstdCompressor", str(type(c))) - - # doesn't support pickle - with self.assertRaises(TypeError): - pickle.dumps(c) - - # supports subclass - class SubClass(ZstdCompressor): - pass - - def test_Decompressor(self): - # method & me_1Mer - ZstdDecompressor() - ZstdDecompressor(TRAINED_DICT, {}) - d = ZstdDecompressor(zstd_dict=TRAINED_DICT, options={}) - - d.decompress(b"") - d.decompress(b"", 100) - d.decompress(data=b"", max_length=100) - - d.eof - d.needs_input - d.unused_data - - # ZstdCompressor attributes - with self.assertRaises(AttributeError): - d.CONTINUE - with self.assertRaises(AttributeError): - d.FLUSH_BLOCK - with self.assertRaises(AttributeError): - d.FLUSH_FRAME - with self.assertRaises(AttributeError): - d.compress(b"") - with self.assertRaises(AttributeError): - d.flush() - - # read only attributes - with self.assertRaises(AttributeError): - d.eof = True - with self.assertRaises(AttributeError): - d.needs_input = True - with self.assertRaises(AttributeError): - d.unused_data = b"" - - # name - self.assertIn(".ZstdDecompressor", str(type(d))) - - # doesn't support pickle - with self.assertRaises(TypeError): - pickle.dumps(d) - - # supports subclass - class 
SubClass(ZstdDecompressor): - pass - - def test_ZstdDict(self): - ZstdDict(b"12345678", True) - zd = ZstdDict(b"12345678", is_raw=True) - - self.assertEqual(type(zd.dict_content), bytes) - self.assertEqual(zd.dict_id, 0) - self.assertEqual(zd.as_digested_dict[1], 0) - self.assertEqual(zd.as_undigested_dict[1], 1) - self.assertEqual(zd.as_prefix[1], 2) - - # name - self.assertIn(".ZstdDict", str(type(zd))) - - # doesn't support pickle - with self.assertRaisesRegex(TypeError, r"cannot pickle"): - pickle.dumps(zd) - with self.assertRaisesRegex(TypeError, r"cannot pickle"): - pickle.dumps(zd.as_prefix) - - # supports subclass - class SubClass(ZstdDict): - pass - - def test_Strategy(self): - # class attributes - Strategy.fast - Strategy.dfast - Strategy.greedy - Strategy.lazy - Strategy.lazy2 - Strategy.btlazy2 - Strategy.btopt - Strategy.btultra - Strategy.btultra2 - - def test_CParameter(self): - CParameter.compressionLevel - CParameter.windowLog - CParameter.hashLog - CParameter.chainLog - CParameter.searchLog - CParameter.minMatch - CParameter.targetLength - CParameter.strategy - with self.assertRaises(NotImplementedError): - CParameter.targetCBlockSize - - CParameter.enableLongDistanceMatching - CParameter.ldmHashLog - CParameter.ldmMinMatch - CParameter.ldmBucketSizeLog - CParameter.ldmHashRateLog - - CParameter.contentSizeFlag - CParameter.checksumFlag - CParameter.dictIDFlag - - CParameter.nbWorkers - CParameter.jobSize - CParameter.overlapLog - - t = CParameter.windowLog.bounds() - self.assertEqual(len(t), 2) - self.assertEqual(type(t[0]), int) - self.assertEqual(type(t[1]), int) - - def test_DParameter(self): - DParameter.windowLogMax - - t = DParameter.windowLogMax.bounds() - self.assertEqual(len(t), 2) - self.assertEqual(type(t[0]), int) - self.assertEqual(type(t[1]), int) - - def test_zstderror_pickle(self): - try: - decompress(b"invalid data") - except Exception as e: - s = pickle.dumps(e) - obj = pickle.loads(s) - self.assertEqual(type(obj), ZstdError) - else: - self.assertFalse(True, "unreachable code path") - - def test_ZstdFile_extend(self): - # These classes and variables can be used to extend ZstdFile, - # so pin them down. 
- self.assertTrue(issubclass(ZstdFile, io.BufferedIOBase)) - self.assertIs(ZstdFile._READER_CLASS, _streams.DecompressReader) - - # mode - self.assertEqual(zstdfile._MODE_CLOSED, 0) - self.assertEqual(zstdfile._MODE_READ, 1) - self.assertEqual(zstdfile._MODE_WRITE, 2) - class CompressorTestCase(unittest.TestCase): From e70e03b4e071ef80a673027bde7af61521324175 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sat, 3 May 2025 19:10:30 -0700 Subject: [PATCH 04/55] Make namedtuples dataclasses --- Lib/compression/zstd/__init__.py | 45 ++++++++++++++++++-------------- Lib/test/test_zstd/test_core.py | 3 ++- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 731da1a9598392..affdf814f57006 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -27,9 +27,9 @@ "ZstdFile", ) -from collections import namedtuple -from enum import IntEnum -from functools import lru_cache +import enum +import functools +import dataclasses from compression.zstd.zstdfile import ZstdFile, open from _zstd import * @@ -43,14 +43,19 @@ _finalize_dict = _zstd._finalize_dict -# TODO(emmatyping): these should be dataclasses or some other class, not namedtuples +@dataclasses.dataclass(frozen=True) +class _CompressionLevelValues: + default: int + min: int + max: int -# compressionLevel_values -_nt_values = namedtuple("values", ["default", "min", "max"]) -compressionLevel_values = _nt_values(*_zstd._compressionLevel_values) +compressionLevel_values = _CompressionLevelValues(*_zstd._compressionLevel_values) - -_nt_frame_info = namedtuple("frame_info", ["decompressed_size", "dictionary_id"]) +@dataclasses.dataclass(frozen=True) +class FrameInfo: + """A dataclass storing information about a Zstandard frame.""" + decompressed_size: int + dictionary_id: int def get_frame_info(frame_buffer): @@ -61,18 +66,20 @@ def get_frame_info(frame_buffer): a frame, and needs to include at least the frame header (6 to 18 bytes). - Return a two-items namedtuple: (decompressed_size, dictionary_id) + Return a FrameInfo dataclass, which currently has two attributes + + 'decompressed_size' is the size in bytes of the data in the frame when + decompressed. If decompressed_size is None, decompressed size is unknown. - dictionary_id is a 32-bit unsigned integer value. 0 means dictionary ID was + 'dictionary_id' is a 32-bit unsigned integer value. 0 means dictionary ID was not recorded in the frame header, the frame may or may not need a dictionary to be decoded, and the ID of such a dictionary is not specified. 
- - It's possible to append more items to the namedtuple in the future.""" + """ ret_tuple = _zstd._get_frame_info(frame_buffer) - return _nt_frame_info(*ret_tuple) + return FrameInfo(*ret_tuple) def _nbytes(dat): @@ -215,7 +222,7 @@ def __get__(self, *_, **__): raise NotImplementedError(msg) -class CParameter(IntEnum): +class CParameter(enum.IntEnum): """Compression parameters""" compressionLevel = _zstd._ZSTD_c_compressionLevel @@ -243,26 +250,26 @@ class CParameter(IntEnum): jobSize = _zstd._ZSTD_c_jobSize overlapLog = _zstd._ZSTD_c_overlapLog - @lru_cache(maxsize=None) + @functools.lru_cache(maxsize=None) def bounds(self): """Return lower and upper bounds of a compression parameter, both inclusive.""" # 1 means compression parameter return _zstd._get_param_bounds(1, self.value) -class DParameter(IntEnum): +class DParameter(enum.IntEnum): """Decompression parameters""" windowLogMax = _zstd._ZSTD_d_windowLogMax - @lru_cache(maxsize=None) + @functools.lru_cache(maxsize=None) def bounds(self): """Return lower and upper bounds of a decompression parameter, both inclusive.""" # 0 means decompression parameter return _zstd._get_param_bounds(0, self.value) -class Strategy(IntEnum): +class Strategy(enum.IntEnum): """Compression strategies, listed from fastest to strongest. Note : new strategies _might_ be added in the future, only the order diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd/test_core.py index 0721beb8060a91..0aa8b014483faa 100644 --- a/Lib/test/test_zstd/test_core.py +++ b/Lib/test/test_zstd/test_core.py @@ -149,7 +149,8 @@ def test_roundtrip_default(self): def test_roundtrip_level(self): raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6] - _default, minv, maxv = compressionLevel_values + minv = compressionLevel_values.min + maxv = compressionLevel_values.max for level in range(max(-20, minv), maxv + 1): dat1 = compress(raw_dat, level) From cbf0ef83aa6ac50700f66094636d9352bafe3f5a Mon Sep 17 00:00:00 2001 From: Emma Smith Date: Sun, 4 May 2025 09:05:33 -0700 Subject: [PATCH 05/55] Apply suggestions from AA-Turner Co-authored-by: Adam Turner <9087854+AA-Turner@users.noreply.github.com> --- Lib/compression/zstd/__init__.py | 92 ++++++++++++++++++-------------- 1 file changed, 52 insertions(+), 40 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index affdf814f57006..4525ab2cea46b3 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -3,28 +3,30 @@ """ __all__ = ( - # From this file + # compression.zstd "compressionLevel_values", - "get_frame_info", + "compress", "CParameter", + "decompress", "DParameter", - "Strategy", "finalize_dict", + "get_frame_info", + "Strategy", "train_dict", "zstd_support_multithread", - "compress", - "decompress", - # From _zstd + + # compression.zstd.zstdfile + "open", + "ZstdFile", + + # _zstd + "get_frame_size", + "zstd_version", + "zstd_version_info", "ZstdCompressor", "ZstdDecompressor", "ZstdDict", "ZstdError", - "get_frame_size", - "zstd_version", - "zstd_version_info", - # From zstd.zstdfile - "open", - "ZstdFile", ) import enum @@ -43,43 +45,55 @@ _finalize_dict = _zstd._finalize_dict -@dataclasses.dataclass(frozen=True) -class _CompressionLevelValues: - default: int - min: int - max: int +class _CLValues: + __slots__ = 'default', 'min', 'max' + + def __init__(self, default, min, max): + super().__setattr__('default', default) + super().__setattr__('min', min) + super().__setattr__('max', max) + + def __repr__(self): + return 
(f'compression_level_values(default={self.default}, ' + f'min={self.min}, max={self.max})') + + def __setattr__(self, name, _): + raise AttributeError(f"can't set attribute {name!r}") compressionLevel_values = _CompressionLevelValues(*_zstd._compressionLevel_values) -@dataclasses.dataclass(frozen=True) class FrameInfo: - """A dataclass storing information about a Zstandard frame.""" - decompressed_size: int - dictionary_id: int + """Information about a Zstandard frame.""" + __slots__ = 'decompressed_size', 'dictionary_id' + + def __init__(self, decompressed_size, dictionary_id): + super().__setattr__('decompressed_size', decompressed_size) + super().__setattr__('dictionary_id', dictionary_id) + + def __repr__(self): + return (f'FrameInfo(decompressed_size={self.decompressed_size}, ' + f'dictionary_id={self.dictionary_id})') + + def __setattr__(self, name, _): + raise AttributeError(f"can't set attribute {name!r}") def get_frame_info(frame_buffer): """Get zstd frame information from a frame header. - Parameter - frame_buffer: A bytes-like object. It should starts from the beginning of - a frame, and needs to include at least the frame header (6 to - 18 bytes). + *frame_buffer* is a bytes-like object. It should starts from the beginning of + a frame, and needs to include at least the frame header (6 to 18 bytes). Return a FrameInfo dataclass, which currently has two attributes 'decompressed_size' is the size in bytes of the data in the frame when - decompressed. - - If decompressed_size is None, decompressed size is unknown. + decompressed, or None when the decompressed size is unknown. 'dictionary_id' is a 32-bit unsigned integer value. 0 means dictionary ID was not recorded in the frame header, the frame may or may not need a dictionary to be decoded, and the ID of such a dictionary is not specified. """ - - ret_tuple = _zstd._get_frame_info(frame_buffer) - return FrameInfo(*ret_tuple) + return FrameInfo(*_zstd._get_frame_info(frame_buffer)) def _nbytes(dat): @@ -92,14 +106,13 @@ def _nbytes(dat): def train_dict(samples, dict_size): """Train a zstd dictionary, return a ZstdDict object. - Parameters - samples: An iterable of samples, a sample is a bytes-like object - represents a file. - dict_size: The dictionary's maximum size, in bytes. + *samples* is an iterable of samples, where a sample is a bytes-like + object representing a file. + *dict_size* is the dictionary's maximum size, in bytes. """ - # Check argument's type if not isinstance(dict_size, int): - raise TypeError('dict_size argument should be an int object.') + ds_cls = type(dict_size).__qualname__ + raise TypeError('dict_size must be an int object, not {ds_cls!r}.') # Prepare data chunks = [] @@ -169,11 +182,10 @@ def finalize_dict(zstd_dict, samples, dict_size, level): dict_content = _finalize_dict(zstd_dict.dict_content, chunks, chunk_sizes, dict_size, level) - - return _zstd.ZstdDict(dict_content) + return ZstdDict(dict_content) def compress(data, level=None, options=None, zstd_dict=None): - """Compress a block of data, return a bytes object of zstd compressed data. + """Return Zstandard compressed *data* as bytes. Refer to ZstdCompressor's docstring for a description of the optional arguments *level*, *options*, and *zstd_dict*. 
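A quick usage sketch of the API surface revised above (not part of the patch; the
sizes shown follow the existing test_get_frame_info test, and FrameInfo attributes
are read-only by design):

    from compression.zstd import compress, get_frame_info

    frame = compress(b'a' * (100 + 32 * 1024))
    info = get_frame_info(frame[:20])    # header is only 6 to 18 bytes
    print(info.decompressed_size)        # 32*1024 + 100 per the test suite,
                                         # or None when not recorded
    print(info.dictionary_id)            # 0: no dictionary ID in the header
    # Assigning to info.decompressed_size / info.dictionary_id raises
    # AttributeError, since __setattr__ is overridden.
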
From 298b369375477cc7e71e8e8e10af550948251799 Mon Sep 17 00:00:00 2001 From: Emma Smith Date: Sun, 4 May 2025 09:09:19 -0700 Subject: [PATCH 06/55] Clean up chunk calculations in train_dict Co-authored-by: Adam Turner <9087854+AA-Turner@users.noreply.github.com> --- Lib/compression/zstd/__init__.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 4525ab2cea46b3..83d1e10cc5d1ea 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -114,20 +114,11 @@ def train_dict(samples, dict_size): ds_cls = type(dict_size).__qualname__ raise TypeError('dict_size must be an int object, not {ds_cls!r}.') - # Prepare data - chunks = [] - chunk_sizes = [] - for chunk in samples: - chunks.append(chunk) - chunk_sizes.append(_nbytes(chunk)) - - chunks = b''.join(chunks) + samples = tuple(samples) + chunks = b''.join(samples) + chunk_sizes = tuple(map(_nbytes, samples)) if not chunks: - raise ValueError("The samples are empty content, can't train dictionary.") - - # samples_bytes: samples be stored concatenated in a single flat buffer. - # samples_size_list: a list of each sample's size. - # dict_size: size of the dictionary, in bytes. + raise ValueError("samples contained no data; can't train dictionary.") dict_content = _train_dict(chunks, chunk_sizes, dict_size) return ZstdDict(dict_content) From a22be68041a140777fbd7b2a3821d5f2d6439992 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 09:35:43 -0700 Subject: [PATCH 07/55] Fix _CLValues instantiation --- Lib/compression/zstd/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 83d1e10cc5d1ea..506361b84ebbb5 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -60,7 +60,7 @@ def __repr__(self): def __setattr__(self, name, _): raise AttributeError(f"can't set attribute {name!r}") -compressionLevel_values = _CompressionLevelValues(*_zstd._compressionLevel_values) +compressionLevel_values = _CLValues(*_zstd._compressionLevel_values) class FrameInfo: """Information about a Zstandard frame.""" From b30ed02b480cc246ac1f94120bceabb7619f257d Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 09:46:16 -0700 Subject: [PATCH 08/55] More cleanup of train_/finalize_dict Co-authored-by: Adam Turner <9087854+AA-Turner@users.noreply.github.com> --- Lib/compression/zstd/__init__.py | 35 ++++++++++---------------------- 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 506361b84ebbb5..e20eb8dc559704 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -31,7 +31,6 @@ import enum import functools -import dataclasses from compression.zstd.zstdfile import ZstdFile, open from _zstd import * @@ -116,7 +115,7 @@ def train_dict(samples, dict_size): samples = tuple(samples) chunks = b''.join(samples) - chunk_sizes = tuple(map(_nbytes, samples)) + chunk_sizes = tuple(_nbytes(sample) for sample in samples) if not chunks: raise ValueError("samples contained no data; can't train dictionary.") dict_content = _train_dict(chunks, chunk_sizes, dict_size) @@ -136,14 +135,13 @@ def finalize_dict(zstd_dict, samples, dict_size, level): dictionary can be a "raw content" dictionary, see is_raw parameter in ZstdDict.__init__ method. 
- Parameters - zstd_dict: A ZstdDict object, basis dictionary. - samples: An iterable of samples, a sample is a bytes-like object - represents a file. - dict_size: The dictionary's maximum size, in bytes. - level: The compression level expected to use in production. The - statistics for each compression level differ, so tuning the - dictionary for the compression level can help quite a bit. + *zstd_dict* is a ZstdDict object, the basis dictionary. + *samples* is an iterable of samples, a sample is a bytes-like object + representing a file. + *dict_size* is the dictionary's maximum size, in bytes. + *level* is the compression level expected to use in production. The + statistics for each compression level differ, so tuning the + dictionary for the compression level can help quite a bit. """ # Check arguments' type @@ -154,22 +152,11 @@ def finalize_dict(zstd_dict, samples, dict_size, level): if not isinstance(level, int): raise TypeError('level argument should be an int object.') - # Prepare data - chunks = [] - chunk_sizes = [] - for chunk in samples: - chunks.append(chunk) - chunk_sizes.append(_nbytes(chunk)) - - chunks = b''.join(chunks) + samples = tuple(samples) + chunks = b''.join(samples) + chunk_sizes = tuple(_nbytes(sample) for sample in samples) if not chunks: raise ValueError("The samples are empty content, can't finalize dictionary.") - - # custom_dict_bytes: existing dictionary. - # samples_bytes: samples be stored concatenated in a single flat buffer. - # samples_size_list: a list of each sample's size. - # dict_size: maximal size of the dictionary, in bytes. - # compression_level: compression level expected to use in production. dict_content = _finalize_dict(zstd_dict.dict_content, chunks, chunk_sizes, dict_size, level) From 307a89406668c3c9a1972c8578e6b83643778c9b Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 10:03:17 -0700 Subject: [PATCH 09/55] Have train_/finalize_dict take tuple not list --- Lib/test/test_zstd/test_core.py | 59 +++++++++------------------- Modules/_zstd/_zstdmodule.c | 33 ++++++++-------- Modules/_zstd/clinic/_zstdmodule.c.h | 38 +++++++++--------- 3 files changed, 53 insertions(+), 77 deletions(-) diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd/test_core.py index 0aa8b014483faa..b88779685ba2db 100644 --- a/Lib/test/test_zstd/test_core.py +++ b/Lib/test/test_zstd/test_core.py @@ -1128,9 +1128,6 @@ def test_train_dict(self): self.assertEqual(sample, dat2) def test_finalize_dict(self): - if zstd_version_info < (1, 4, 5): - return - DICT_SIZE2 = 200*1024 C_LEVEL = 6 @@ -1174,16 +1171,8 @@ def test_train_dict_arguments(self): train_dict(SAMPLES, 0) def test_finalize_dict_arguments(self): - if zstd_version_info < (1, 4, 5): - with self.assertRaises(NotImplementedError): - finalize_dict({1:2}, [b'aaa', b'bbb'], 100*_1K, 2) - return - - try: - finalize_dict(TRAINED_DICT, SAMPLES, 1*_1M, 2) - except NotImplementedError: - # < v1.4.5 at compile-time, >= v.1.4.5 at run-time - return + with self.assertRaises(TypeError): + finalize_dict({1:2}, (b'aaa', b'bbb'), 100*_1K, 2) with self.assertRaises(ValueError): finalize_dict(TRAINED_DICT, [], 100*_1K, 2) @@ -1197,51 +1186,43 @@ def test_finalize_dict_arguments(self): def test_train_dict_c(self): # argument wrong type with self.assertRaises(TypeError): - _zstd._train_dict({}, [], 100) + _zstd._train_dict({}, (), 100) with self.assertRaises(TypeError): _zstd._train_dict(b'', 99, 100) with self.assertRaises(TypeError): - _zstd._train_dict(b'', [], 100.1) + _zstd._train_dict(b'', 
(), 100.1) # size > size_t with self.assertRaises(ValueError): - _zstd._train_dict(b'', [2**64+1], 100) + _zstd._train_dict(b'', (2**64+1,), 100) # dict_size <= 0 with self.assertRaises(ValueError): - _zstd._train_dict(b'', [], 0) + _zstd._train_dict(b'', (), 0) def test_finalize_dict_c(self): - if zstd_version_info < (1, 4, 5): - with self.assertRaises(NotImplementedError): - _zstd._finalize_dict(1, 2, 3, 4, 5) - return - - try: - _zstd._finalize_dict(TRAINED_DICT.dict_content, b'123', [3,], 1*_1M, 5) - except NotImplementedError: - # < v1.4.5 at compile-time, >= v.1.4.5 at run-time - return + with self.assertRaises(TypeError): + _zstd._finalize_dict(1, 2, 3, 4, 5) # argument wrong type with self.assertRaises(TypeError): - _zstd._finalize_dict({}, b'', [], 100, 5) + _zstd._finalize_dict({}, b'', (), 100, 5) with self.assertRaises(TypeError): - _zstd._finalize_dict(TRAINED_DICT.dict_content, {}, [], 100, 5) + _zstd._finalize_dict(TRAINED_DICT.dict_content, {}, (), 100, 5) with self.assertRaises(TypeError): _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', 99, 100, 5) with self.assertRaises(TypeError): - _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', [], 100.1, 5) + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', (), 100.1, 5) with self.assertRaises(TypeError): - _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', [], 100, 5.1) + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5.1) # size > size_t with self.assertRaises(ValueError): - _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', [2**64+1], 100, 5) + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', (2**64+1,), 100, 5) # dict_size <= 0 with self.assertRaises(ValueError): - _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', [], 0, 5) + _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', (), 0, 5) def test_train_buffer_protocol_samples(self): def _nbytes(dat): @@ -1263,24 +1244,20 @@ def _nbytes(dat): # wrong size list with self.assertRaisesRegex(ValueError, "The samples size list doesn't match the concatenation's size"): - _zstd._train_dict(concatenation, wrong_size_lst, 100*1024) + _zstd._train_dict(concatenation, tuple(wrong_size_lst), 100*1024) # correct size list - _zstd._train_dict(concatenation, correct_size_lst, 3*1024) - - # test _finalize_dict - if zstd_version_info < (1, 4, 5): - return + _zstd._train_dict(concatenation, tuple(correct_size_lst), 3*1024) # wrong size list with self.assertRaisesRegex(ValueError, "The samples size list doesn't match the concatenation's size"): _zstd._finalize_dict(TRAINED_DICT.dict_content, - concatenation, wrong_size_lst, 300*1024, 5) + concatenation, tuple(wrong_size_lst), 300*1024, 5) # correct size list _zstd._finalize_dict(TRAINED_DICT.dict_content, - concatenation, correct_size_lst, 300*1024, 5) + concatenation, tuple(correct_size_lst), 300*1024, 5) def test_as_prefix(self): # V1 diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c index 18dc13b3fd16f0..994f151f1f449b 100644 --- a/Modules/_zstd/_zstdmodule.c +++ b/Modules/_zstd/_zstdmodule.c @@ -180,8 +180,8 @@ _zstd._train_dict samples_bytes: PyBytesObject Concatenation of samples. - samples_size_list: object(subclass_of='&PyList_Type') - List of samples' sizes. + samples_sizes: object(subclass_of='&PyTuple_Type') + Tuple of samples' sizes. dict_size: Py_ssize_t The size of the dictionary. / @@ -191,8 +191,8 @@ Internal function, train a zstd dictionary on sample data. 
static PyObject * _zstd__train_dict_impl(PyObject *module, PyBytesObject *samples_bytes, - PyObject *samples_size_list, Py_ssize_t dict_size) -/*[clinic end generated code: output=ee53c34c8f77886b input=b21d092c695a3a81]*/ + PyObject *samples_sizes, Py_ssize_t dict_size) +/*[clinic end generated code: output=b5b4f36347c0addd input=2dce5b57d63923e2]*/ { // TODO(emmatyping): The preamble and suffix to this function and _finalize_dict // are pretty similar. We should see if we can refactor them to share that code. @@ -209,7 +209,7 @@ _zstd__train_dict_impl(PyObject *module, PyBytesObject *samples_bytes, return NULL; } - chunks_number = Py_SIZE(samples_size_list); + chunks_number = Py_SIZE(samples_sizes); if ((size_t) chunks_number > UINT32_MAX) { PyErr_Format(PyExc_ValueError, "The number of samples should be <= %u.", UINT32_MAX); @@ -225,12 +225,11 @@ _zstd__train_dict_impl(PyObject *module, PyBytesObject *samples_bytes, sizes_sum = 0; for (i = 0; i < chunks_number; i++) { - PyObject *size = PyList_GetItemRef(samples_size_list, i); + PyObject *size = PyTuple_GetItem(samples_sizes, i); chunk_sizes[i] = PyLong_AsSize_t(size); - Py_DECREF(size); if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) { PyErr_Format(PyExc_ValueError, - "Items in samples_size_list should be an int " + "Items in samples_sizes should be an int " "object, with a value between 0 and %u.", SIZE_MAX); goto error; } @@ -239,7 +238,7 @@ _zstd__train_dict_impl(PyObject *module, PyBytesObject *samples_bytes, if (sizes_sum != Py_SIZE(samples_bytes)) { PyErr_SetString(PyExc_ValueError, - "The samples size list doesn't match the concatenation's size."); + "The samples size tuple doesn't match the concatenation's size."); goto error; } @@ -287,8 +286,8 @@ _zstd._finalize_dict Custom dictionary content. samples_bytes: PyBytesObject Concatenation of samples. - samples_size_list: object(subclass_of='&PyList_Type') - List of samples' sizes. + samples_sizes: object(subclass_of='&PyTuple_Type') + Tuple of samples' sizes. dict_size: Py_ssize_t The size of the dictionary. compression_level: int @@ -301,9 +300,9 @@ Internal function, finalize a zstd dictionary. 
static PyObject * _zstd__finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes, PyBytesObject *samples_bytes, - PyObject *samples_size_list, Py_ssize_t dict_size, + PyObject *samples_sizes, Py_ssize_t dict_size, int compression_level) -/*[clinic end generated code: output=9c2a7d8c845cee93 input=08531a803d87c56f]*/ +/*[clinic end generated code: output=5dc5b520fddba37f input=8afd42a249078460]*/ { Py_ssize_t chunks_number; size_t *chunk_sizes = NULL; @@ -319,7 +318,7 @@ _zstd__finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes, return NULL; } - chunks_number = Py_SIZE(samples_size_list); + chunks_number = Py_SIZE(samples_sizes); if ((size_t) chunks_number > UINT32_MAX) { PyErr_Format(PyExc_ValueError, "The number of samples should be <= %u.", UINT32_MAX); @@ -335,11 +334,11 @@ _zstd__finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes, sizes_sum = 0; for (i = 0; i < chunks_number; i++) { - PyObject *size = PyList_GET_ITEM(samples_size_list, i); + PyObject *size = PyTuple_GetItem(samples_sizes, i); chunk_sizes[i] = PyLong_AsSize_t(size); if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) { PyErr_Format(PyExc_ValueError, - "Items in samples_size_list should be an int " + "Items in samples_sizes should be an int " "object, with a value between 0 and %u.", SIZE_MAX); goto error; } @@ -348,7 +347,7 @@ _zstd__finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes, if (sizes_sum != Py_SIZE(samples_bytes)) { PyErr_SetString(PyExc_ValueError, - "The samples size list doesn't match the concatenation's size."); + "The samples size tuple doesn't match the concatenation's size."); goto error; } diff --git a/Modules/_zstd/clinic/_zstdmodule.c.h b/Modules/_zstd/clinic/_zstdmodule.c.h index 4b78bded67bca7..94f14ed858cdae 100644 --- a/Modules/_zstd/clinic/_zstdmodule.c.h +++ b/Modules/_zstd/clinic/_zstdmodule.c.h @@ -10,15 +10,15 @@ preserve #include "pycore_modsupport.h" // _PyArg_CheckPositional() PyDoc_STRVAR(_zstd__train_dict__doc__, -"_train_dict($module, samples_bytes, samples_size_list, dict_size, /)\n" +"_train_dict($module, samples_bytes, samples_sizes, dict_size, /)\n" "--\n" "\n" "Internal function, train a zstd dictionary on sample data.\n" "\n" " samples_bytes\n" " Concatenation of samples.\n" -" samples_size_list\n" -" List of samples\' sizes.\n" +" samples_sizes\n" +" Tuple of samples\' sizes.\n" " dict_size\n" " The size of the dictionary."); @@ -27,14 +27,14 @@ PyDoc_STRVAR(_zstd__train_dict__doc__, static PyObject * _zstd__train_dict_impl(PyObject *module, PyBytesObject *samples_bytes, - PyObject *samples_size_list, Py_ssize_t dict_size); + PyObject *samples_sizes, Py_ssize_t dict_size); static PyObject * _zstd__train_dict(PyObject *module, PyObject *const *args, Py_ssize_t nargs) { PyObject *return_value = NULL; PyBytesObject *samples_bytes; - PyObject *samples_size_list; + PyObject *samples_sizes; Py_ssize_t dict_size; if (!_PyArg_CheckPositional("_train_dict", nargs, 3, 3)) { @@ -45,11 +45,11 @@ _zstd__train_dict(PyObject *module, PyObject *const *args, Py_ssize_t nargs) goto exit; } samples_bytes = (PyBytesObject *)args[0]; - if (!PyList_Check(args[1])) { - _PyArg_BadArgument("_train_dict", "argument 2", "list", args[1]); + if (!PyTuple_Check(args[1])) { + _PyArg_BadArgument("_train_dict", "argument 2", "tuple", args[1]); goto exit; } - samples_size_list = args[1]; + samples_sizes = args[1]; { Py_ssize_t ival = -1; PyObject *iobj = _PyNumber_Index(args[2]); @@ -62,7 +62,7 @@ _zstd__train_dict(PyObject *module, 
PyObject *const *args, Py_ssize_t nargs) } dict_size = ival; } - return_value = _zstd__train_dict_impl(module, samples_bytes, samples_size_list, dict_size); + return_value = _zstd__train_dict_impl(module, samples_bytes, samples_sizes, dict_size); exit: return return_value; @@ -70,7 +70,7 @@ _zstd__train_dict(PyObject *module, PyObject *const *args, Py_ssize_t nargs) PyDoc_STRVAR(_zstd__finalize_dict__doc__, "_finalize_dict($module, custom_dict_bytes, samples_bytes,\n" -" samples_size_list, dict_size, compression_level, /)\n" +" samples_sizes, dict_size, compression_level, /)\n" "--\n" "\n" "Internal function, finalize a zstd dictionary.\n" @@ -79,8 +79,8 @@ PyDoc_STRVAR(_zstd__finalize_dict__doc__, " Custom dictionary content.\n" " samples_bytes\n" " Concatenation of samples.\n" -" samples_size_list\n" -" List of samples\' sizes.\n" +" samples_sizes\n" +" Tuple of samples\' sizes.\n" " dict_size\n" " The size of the dictionary.\n" " compression_level\n" @@ -92,7 +92,7 @@ PyDoc_STRVAR(_zstd__finalize_dict__doc__, static PyObject * _zstd__finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes, PyBytesObject *samples_bytes, - PyObject *samples_size_list, Py_ssize_t dict_size, + PyObject *samples_sizes, Py_ssize_t dict_size, int compression_level); static PyObject * @@ -101,7 +101,7 @@ _zstd__finalize_dict(PyObject *module, PyObject *const *args, Py_ssize_t nargs) PyObject *return_value = NULL; PyBytesObject *custom_dict_bytes; PyBytesObject *samples_bytes; - PyObject *samples_size_list; + PyObject *samples_sizes; Py_ssize_t dict_size; int compression_level; @@ -118,11 +118,11 @@ _zstd__finalize_dict(PyObject *module, PyObject *const *args, Py_ssize_t nargs) goto exit; } samples_bytes = (PyBytesObject *)args[1]; - if (!PyList_Check(args[2])) { - _PyArg_BadArgument("_finalize_dict", "argument 3", "list", args[2]); + if (!PyTuple_Check(args[2])) { + _PyArg_BadArgument("_finalize_dict", "argument 3", "tuple", args[2]); goto exit; } - samples_size_list = args[2]; + samples_sizes = args[2]; { Py_ssize_t ival = -1; PyObject *iobj = _PyNumber_Index(args[3]); @@ -139,7 +139,7 @@ _zstd__finalize_dict(PyObject *module, PyObject *const *args, Py_ssize_t nargs) if (compression_level == -1 && PyErr_Occurred()) { goto exit; } - return_value = _zstd__finalize_dict_impl(module, custom_dict_bytes, samples_bytes, samples_size_list, dict_size, compression_level); + return_value = _zstd__finalize_dict_impl(module, custom_dict_bytes, samples_bytes, samples_sizes, dict_size, compression_level); exit: return return_value; @@ -429,4 +429,4 @@ _zstd__set_parameter_types(PyObject *module, PyObject *const *args, Py_ssize_t n exit: return return_value; } -/*[clinic end generated code: output=077c8ea2b11fb188 input=a9049054013a1b77]*/ +/*[clinic end generated code: output=f4530f3e3439cbe7 input=a9049054013a1b77]*/ From dd716a4c800806c53a26f22ba08c7bebe60a1cba Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 10:32:02 -0700 Subject: [PATCH 10/55] Ensure trailing data raises errors --- Lib/compression/zstd/__init__.py | 9 +-------- Lib/test/test_zstd/test_core.py | 24 ++++++++---------------- 2 files changed, 9 insertions(+), 24 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index e20eb8dc559704..653be5fb307a87 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -184,14 +184,7 @@ def decompress(data, zstd_dict=None, options=None): results = [] while True: decomp = ZstdDecompressor(options=options, 
zstd_dict=zstd_dict) - try: - res = decomp.decompress(data) - except ZstdError: - if results: - break # Leftover data is not a valid LZMA/XZ stream; ignore it. - else: - raise # Error on the first iteration; bail out. - results.append(res) + results.append(decomp.decompress(data)) if not decomp.eof: raise ZstdError("Compressed data ended before the " "end-of-stream marker was reached") diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd/test_core.py index b88779685ba2db..b5c1782a06f55e 100644 --- a/Lib/test/test_zstd/test_core.py +++ b/Lib/test/test_zstd/test_core.py @@ -736,15 +736,11 @@ def test_function_decompress(self): with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): decompress(b'aaaaaaaaa') - self.assertEqual( - decompress(self.FRAME_42 + b'aaaaaaaaa'), - self.DECOMPRESSED_42 - ) + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(self.FRAME_42 + b'aaaaaaaaa') - self.assertEqual( - decompress(self.UNKNOWN_FRAME_42_60 + b'aaaaaaaaa'), - self.DECOMPRESSED_42_60 - ) + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(self.UNKNOWN_FRAME_42_60 + b'aaaaaaaaa') # doesn't match checksum checksum = DAT_130K_C[-4:] @@ -803,15 +799,11 @@ def test_function_skippable(self): with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): decompress(b'aaaaaaaaa' + SKIPPABLE_FRAME) - self.assertEqual( - decompress(SKIPPABLE_FRAME + b'aaaaaaaaa'), - b'' - ) + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(SKIPPABLE_FRAME + b'aaaaaaaaa') - self.assertEqual( - decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME + b'aaaaaaaaa'), - b'' - ) + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME + b'aaaaaaaaa') def test_decompressor_1(self): # empty 1 From 9b4765b23b594a256d31fc25b32519a7e88f93da Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 10:34:00 -0700 Subject: [PATCH 11/55] Remove paramter bounds caching and unsupported... parameters. 
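
With the lru_cache and the _UnsupportedCParameter shim removed, bounds() is a
plain method that queries the linked zstd build on every call. A minimal sketch
of the intended use (illustrative only; the exact bounds depend on the zstd
version and whether the build is 32- or 64-bit):

    from compression.zstd import CParameter, DParameter

    lo, hi = CParameter.windowLog.bounds()      # both inclusive, e.g. (10, 31)
    assert isinstance(lo, int) and lo < hi
    lo, hi = DParameter.windowLogMax.bounds()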
--- Lib/compression/zstd/__init__.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 653be5fb307a87..6a3b383d27c9b0 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -30,7 +30,6 @@ ) import enum -import functools from compression.zstd.zstdfile import ZstdFile, open from _zstd import * @@ -193,17 +192,6 @@ def decompress(data, zstd_dict=None, options=None): break return b"".join(results) -class _UnsupportedCParameter: - def __set_name__(self, _, name): - self.name = name - - def __get__(self, *_, **__): - msg = ("%s CParameter not available, zstd version is %s.") % ( - self.name, - zstd_version, - ) - raise NotImplementedError(msg) - class CParameter(enum.IntEnum): """Compression parameters""" @@ -217,8 +205,6 @@ class CParameter(enum.IntEnum): targetLength = _zstd._ZSTD_c_targetLength strategy = _zstd._ZSTD_c_strategy - targetCBlockSize = _UnsupportedCParameter() - enableLongDistanceMatching = _zstd._ZSTD_c_enableLongDistanceMatching ldmHashLog = _zstd._ZSTD_c_ldmHashLog ldmMinMatch = _zstd._ZSTD_c_ldmMinMatch @@ -233,7 +219,6 @@ class CParameter(enum.IntEnum): jobSize = _zstd._ZSTD_c_jobSize overlapLog = _zstd._ZSTD_c_overlapLog - @functools.lru_cache(maxsize=None) def bounds(self): """Return lower and upper bounds of a compression parameter, both inclusive.""" # 1 means compression parameter @@ -245,7 +230,6 @@ class DParameter(enum.IntEnum): windowLogMax = _zstd._ZSTD_d_windowLogMax - @functools.lru_cache(maxsize=None) def bounds(self): """Return lower and upper bounds of a decompression parameter, both inclusive.""" # 0 means decompression parameter From 214cd60c539512b2897263f602731a34d0d22046 Mon Sep 17 00:00:00 2001 From: Emma Smith Date: Sun, 4 May 2025 10:46:03 -0700 Subject: [PATCH 12/55] Use kwargs for code clarity Co-authored-by: Adam Turner <9087854+AA-Turner@users.noreply.github.com> --- Lib/compression/zstd/__init__.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 6a3b383d27c9b0..5d721cd519e993 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -170,7 +170,7 @@ def compress(data, level=None, options=None, zstd_dict=None): For incremental compression, use an ZstdCompressor instead. """ comp = ZstdCompressor(level=level, options=options, zstd_dict=zstd_dict) - return comp.compress(data, ZstdCompressor.FLUSH_FRAME) + return comp.compress(data, mode=ZstdCompressor.FLUSH_FRAME) def decompress(data, zstd_dict=None, options=None): """Decompress one or more frames of data. 
@@ -221,8 +221,7 @@ class CParameter(enum.IntEnum): def bounds(self): """Return lower and upper bounds of a compression parameter, both inclusive.""" - # 1 means compression parameter - return _zstd._get_param_bounds(1, self.value) + return _zstd._get_param_bounds(is_compress=True, parameter=self.value) class DParameter(enum.IntEnum): @@ -232,8 +231,7 @@ class DParameter(enum.IntEnum): def bounds(self): """Return lower and upper bounds of a decompression parameter, both inclusive.""" - # 0 means decompression parameter - return _zstd._get_param_bounds(0, self.value) + return _zstd._get_param_bounds(is_compress=False, parameter=self.value) class Strategy(enum.IntEnum): From e1f53b17cec96abfa9c1f7da1b449f8ee4fabb5f Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 10:54:04 -0700 Subject: [PATCH 13/55] Clean up imports in zstd tests --- Lib/test/test_zstd/test_core.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd/test_core.py index b5c1782a06f55e..edb5095bab4ba9 100644 --- a/Lib/test/test_zstd/test_core.py +++ b/Lib/test/test_zstd/test_core.py @@ -2,7 +2,6 @@ import gc import io import pathlib -import pickle import random import builtins import re @@ -11,17 +10,16 @@ import tempfile import threading -from compression._common import _streams - from test.support.import_helper import import_module from test.support import threading_helper from test.support import _1M from test.support import Py_GIL_DISABLED -zstd = import_module("compression.zstd") _zstd = import_module("_zstd") +zstd = import_module("compression.zstd") + from compression.zstd import ( - zstdfile, + open, compress, decompress, ZstdCompressor, @@ -41,7 +39,6 @@ ZstdFile, zstd_support_multithread, ) -from compression.zstd.zstdfile import open _1K = 1024 _130_1K = 130 * _1K From 4c000262ee93b34544dcf962cd1521982f00f790 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 10:58:03 -0700 Subject: [PATCH 14/55] Use _1K instead of 1024 in tests --- Lib/test/test_zstd/test_core.py | 48 ++++++++++++++++----------------- 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd/test_core.py index edb5095bab4ba9..5e6b3267ad776a 100644 --- a/Lib/test/test_zstd/test_core.py +++ b/Lib/test/test_zstd/test_core.py @@ -65,14 +65,12 @@ TRAINED_DICT = None -KB = 1024 -MB = 1024*1024 def setUpModule(): # uncompressed size 130KB, more than a zstd block. # with a frame epilogue, 4 bytes checksum. 
global DAT_130K_D - DAT_130K_D = bytes([random.randint(0, 127) for _ in range(130*1024)]) + DAT_130K_D = bytes([random.randint(0, 127) for _ in range(130*_1K)]) global DAT_130K_C DAT_130K_C = compress(DAT_130K_D, options={CParameter.checksumFlag:1}) @@ -84,15 +82,15 @@ def setUpModule(): COMPRESSED_DAT = compress(DECOMPRESSED_DAT) global DECOMPRESSED_100_PLUS_32KB - DECOMPRESSED_100_PLUS_32KB = b'a' * (100 + 32*1024) + DECOMPRESSED_100_PLUS_32KB = b'a' * (100 + 32*_1K) global COMPRESSED_100_PLUS_32KB COMPRESSED_100_PLUS_32KB = compress(DECOMPRESSED_100_PLUS_32KB) global SKIPPABLE_FRAME SKIPPABLE_FRAME = (0x184D2A50).to_bytes(4, byteorder='little') + \ - (32*1024).to_bytes(4, byteorder='little') + \ - b'a' * (32*1024) + (32*_1K).to_bytes(4, byteorder='little') + \ + b'a' * (32*_1K) global THIS_FILE_BYTES, THIS_FILE_STR with builtins.open(os.path.abspath(__file__), 'rb') as f: @@ -122,8 +120,8 @@ def setUpModule(): assert len(SAMPLES) > 10 global TRAINED_DICT - TRAINED_DICT = train_dict(SAMPLES, 3*1024) - assert len(TRAINED_DICT.dict_content) <= 3*1024 + TRAINED_DICT = train_dict(SAMPLES, 3*_1K) + assert len(TRAINED_DICT.dict_content) <= 3*_1K class FunctionsTestCase(unittest.TestCase): @@ -157,7 +155,7 @@ def test_roundtrip_level(self): def test_get_frame_info(self): # no dict info = get_frame_info(COMPRESSED_100_PLUS_32KB[:20]) - self.assertEqual(info.decompressed_size, 32 * 1024 + 100) + self.assertEqual(info.decompressed_size, 32 * _1K + 100) self.assertEqual(info.dictionary_id, 0) # use dict @@ -675,7 +673,7 @@ def setUpClass(cls): cls.FRAME_42_60 = cls.FRAME_42 + cls.FRAME_60 cls.DECOMPRESSED_42_60 = cls.DECOMPRESSED_42 + cls.DECOMPRESSED_60 - cls._130_1K = 130*1024 + cls._130_1K = 130*_1K c = ZstdCompressor() cls.UNKNOWN_FRAME_42 = c.compress(cls.DECOMPRESSED_42) + c.flush() @@ -686,7 +684,7 @@ def setUpClass(cls): def test_function_decompress(self): - self.assertEqual(len(decompress(COMPRESSED_100_PLUS_32KB)), 100+32*1024) + self.assertEqual(len(decompress(COMPRESSED_100_PLUS_32KB)), 100+32*_1K) # 1 frame self.assertEqual(decompress(self.FRAME_42), self.DECOMPRESSED_42) @@ -875,11 +873,11 @@ def test_decompressor_1(self): self.assertEqual(d.unused_data, self.TRAIL) # twice # 1 frame, 32_1K - temp = compress(b'a'*(32*1024)) + temp = compress(b'a'*(32*_1K)) d = ZstdDecompressor() - dat = d.decompress(temp, 32*1024) + dat = d.decompress(temp, 32*_1K) - self.assertEqual(dat, b'a'*(32*1024)) + self.assertEqual(dat, b'a'*(32*_1K)) self.assertTrue(d.eof) self.assertFalse(d.needs_input) self.assertEqual(d.unused_data, b'') @@ -899,7 +897,7 @@ def test_decompressor_1(self): dat = d.decompress(b'') # 32_1K - self.assertEqual(len(dat), 32*1024) + self.assertEqual(len(dat), 32*_1K) self.assertTrue(d.eof) self.assertFalse(d.needs_input) self.assertEqual(d.unused_data, self.TRAIL) @@ -1117,7 +1115,7 @@ def test_train_dict(self): self.assertEqual(sample, dat2) def test_finalize_dict(self): - DICT_SIZE2 = 200*1024 + DICT_SIZE2 = 200*_1K C_LEVEL = 6 try: @@ -1233,20 +1231,20 @@ def _nbytes(dat): # wrong size list with self.assertRaisesRegex(ValueError, "The samples size list doesn't match the concatenation's size"): - _zstd._train_dict(concatenation, tuple(wrong_size_lst), 100*1024) + _zstd._train_dict(concatenation, tuple(wrong_size_lst), 100*_1K) # correct size list - _zstd._train_dict(concatenation, tuple(correct_size_lst), 3*1024) + _zstd._train_dict(concatenation, tuple(correct_size_lst), 3*_1K) # wrong size list with self.assertRaisesRegex(ValueError, "The samples size list doesn't match the 
concatenation's size"): _zstd._finalize_dict(TRAINED_DICT.dict_content, - concatenation, tuple(wrong_size_lst), 300*1024, 5) + concatenation, tuple(wrong_size_lst), 300*_1K, 5) # correct size list _zstd._finalize_dict(TRAINED_DICT.dict_content, - concatenation, tuple(correct_size_lst), 300*1024, 5) + concatenation, tuple(correct_size_lst), 300*_1K, 5) def test_as_prefix(self): # V1 @@ -1659,7 +1657,7 @@ def test_read_truncated(self): with ZstdFile(io.BytesIO(truncated)) as f: # this is an important test, make sure it doesn't raise EOFError. - self.assertEqual(f.read(130*1024), DAT_130K_D) + self.assertEqual(f.read(130*_1K), DAT_130K_D) with self.assertRaises(EOFError): f.read(1) @@ -1789,7 +1787,7 @@ def test_iterator(self): self.assertListEqual(f.readlines(), lines) def test_decompress_limited(self): - _ZSTD_DStreamInSize = 128*1024 + 3 + _ZSTD_DStreamInSize = 128*_1K + 3 bomb = compress(b'\0' * int(2e6), level=10) self.assertLess(len(bomb), _ZSTD_DStreamInSize) @@ -1798,7 +1796,7 @@ def test_decompress_limited(self): self.assertEqual(decomp.read(1), b'\0') # BufferedReader uses 128 KiB buffer in __init__.py - max_decomp = 128*1024 + max_decomp = 128*_1K self.assertLessEqual(decomp._buffer.raw.tell(), max_decomp, "Excessive amount of data was decompressed") @@ -1913,8 +1911,8 @@ def comp(data): comp = ZstdCompressor() return comp.compress(data) + comp.flush() - part1 = THIS_FILE_BYTES[:1024] - part2 = THIS_FILE_BYTES[1024:1536] + part1 = THIS_FILE_BYTES[:_1K] + part2 = THIS_FILE_BYTES[_1K:1536] part3 = THIS_FILE_BYTES[1536:] expected = b"".join(comp(x) for x in (part1, part2, part3)) with io.BytesIO() as dst: From 99653d201bb503d288805028c0a66341f06e079e Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 11:00:13 -0700 Subject: [PATCH 15/55] Move compression.zstd.zstdfile to compression.zstd._zstdfile This makes it private. 
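
The submodule becomes an implementation detail; user code keeps importing the
public names from the package itself. A sketch of the expected usage (the file
name is hypothetical, not taken from the patch):

    from compression.zstd import ZstdFile, open as zstd_open

    with zstd_open('archive.zst', 'wb') as f:
        f.write(b'payload')
    with ZstdFile('archive.zst') as f:
        data = f.read()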
--- Lib/compression/zstd/__init__.py | 4 ++-- Lib/compression/zstd/{zstdfile.py => _zstdfile.py} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename Lib/compression/zstd/{zstdfile.py => _zstdfile.py} (100%) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 5d721cd519e993..23eea37df14148 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -15,7 +15,7 @@ "train_dict", "zstd_support_multithread", - # compression.zstd.zstdfile + # compression.zstd._zstdfile "open", "ZstdFile", @@ -31,7 +31,7 @@ import enum -from compression.zstd.zstdfile import ZstdFile, open +from compression.zstd._zstdfile import ZstdFile, open from _zstd import * import _zstd diff --git a/Lib/compression/zstd/zstdfile.py b/Lib/compression/zstd/_zstdfile.py similarity index 100% rename from Lib/compression/zstd/zstdfile.py rename to Lib/compression/zstd/_zstdfile.py From e403a2582d2d56c790e468bef8a12c62b8903d99 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 11:16:05 -0700 Subject: [PATCH 16/55] Change compressLevel_values to COMPRESSION_LEVEL_DEFAULT --- Lib/compression/zstd/__init__.py | 16 +--------------- Lib/test/test_zstd/test_core.py | 25 +++++++++++++------------ 2 files changed, 14 insertions(+), 27 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 23eea37df14148..ca287ea3f81057 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -43,22 +43,8 @@ _finalize_dict = _zstd._finalize_dict -class _CLValues: - __slots__ = 'default', 'min', 'max' +COMPRESSION_LEVEL_DEFAULT = _zstd._compressionLevel_values[0] - def __init__(self, default, min, max): - super().__setattr__('default', default) - super().__setattr__('min', min) - super().__setattr__('max', max) - - def __repr__(self): - return (f'compression_level_values(default={self.default}, ' - f'min={self.min}, max={self.max})') - - def __setattr__(self, name, _): - raise AttributeError(f"can't set attribute {name!r}") - -compressionLevel_values = _CLValues(*_zstd._compressionLevel_values) class FrameInfo: """Information about a Zstandard frame.""" diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd/test_core.py index 5e6b3267ad776a..9b9ae5251786ad 100644 --- a/Lib/test/test_zstd/test_core.py +++ b/Lib/test/test_zstd/test_core.py @@ -28,7 +28,7 @@ ZstdError, zstd_version, zstd_version_info, - compressionLevel_values, + COMPRESSION_LEVEL_DEFAULT, get_frame_info, get_frame_size, finalize_dict, @@ -131,10 +131,11 @@ def test_version(self): self.assertEqual(s, zstd_version) def test_compressionLevel_values(self): - self.assertIs(type(compressionLevel_values.default), int) - self.assertIs(type(compressionLevel_values.min), int) - self.assertIs(type(compressionLevel_values.max), int) - self.assertLess(compressionLevel_values.min, compressionLevel_values.max) + min, max = CParameter.compressionLevel.bounds() + self.assertIs(type(COMPRESSION_LEVEL_DEFAULT), int) + self.assertIs(type(min), int) + self.assertIs(type(max), int) + self.assertLess(min, max) def test_roundtrip_default(self): raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6] @@ -144,10 +145,9 @@ def test_roundtrip_default(self): def test_roundtrip_level(self): raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6] - minv = compressionLevel_values.min - maxv = compressionLevel_values.max + level_min, level_max = CParameter.compressionLevel.bounds() - for level in range(max(-20, minv), maxv + 1): + for level in 
range(max(-20, level_min), level_max + 1): dat1 = compress(raw_dat, level) dat2 = decompress(dat1) self.assertEqual(dat2, raw_dat) @@ -259,11 +259,12 @@ def test_compress_parameters(self): self.assertRaises(ValueError, ZstdCompressor, d1) # clamp compressionLevel - compress(b'', compressionLevel_values.max+1) - compress(b'', compressionLevel_values.min-1) + level_min, level_max = CParameter.compressionLevel.bounds() + compress(b'', level_max+1) + compress(b'', level_min-1) - compress(b'', {CParameter.compressionLevel:compressionLevel_values.max+1}) - compress(b'', {CParameter.compressionLevel:compressionLevel_values.min-1}) + compress(b'', {CParameter.compressionLevel:level_max+1}) + compress(b'', {CParameter.compressionLevel:level_min-1}) # zstd lib doesn't support MT compression if not zstd_support_multithread: From 63625bc5b246f96760fe73673656d266d5847941 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 11:16:59 -0700 Subject: [PATCH 17/55] Fix tests for change in error message --- Lib/test/test_zstd/test_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd/test_core.py index 9b9ae5251786ad..a57fc6a41cd576 100644 --- a/Lib/test/test_zstd/test_core.py +++ b/Lib/test/test_zstd/test_core.py @@ -1231,7 +1231,7 @@ def _nbytes(dat): # wrong size list with self.assertRaisesRegex(ValueError, - "The samples size list doesn't match the concatenation's size"): + "The samples size tuple doesn't match the concatenation's size"): _zstd._train_dict(concatenation, tuple(wrong_size_lst), 100*_1K) # correct size list @@ -1239,7 +1239,7 @@ def _nbytes(dat): # wrong size list with self.assertRaisesRegex(ValueError, - "The samples size list doesn't match the concatenation's size"): + "The samples size tuple doesn't match the concatenation's size"): _zstd._finalize_dict(TRAINED_DICT.dict_content, concatenation, tuple(wrong_size_lst), 300*_1K, 5) From 1ea4b9a28dd8a2cf469089b1e0a0cc3b8595c012 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 11:46:06 -0700 Subject: [PATCH 18/55] Make parameter names snake case --- Lib/compression/zstd/__init__.py | 40 +++++------ Lib/test/test_zstd/test_core.py | 118 +++++++++++++++---------------- Modules/_zstd/_zstdmodule.c | 38 +++++----- 3 files changed, 98 insertions(+), 98 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index ca287ea3f81057..a1046fd0d77b29 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -182,28 +182,28 @@ def decompress(data, zstd_dict=None, options=None): class CParameter(enum.IntEnum): """Compression parameters""" - compressionLevel = _zstd._ZSTD_c_compressionLevel - windowLog = _zstd._ZSTD_c_windowLog - hashLog = _zstd._ZSTD_c_hashLog - chainLog = _zstd._ZSTD_c_chainLog - searchLog = _zstd._ZSTD_c_searchLog - minMatch = _zstd._ZSTD_c_minMatch - targetLength = _zstd._ZSTD_c_targetLength + compression_level = _zstd._ZSTD_c_compressionLevel + window_log = _zstd._ZSTD_c_windowLog + hash_log = _zstd._ZSTD_c_hashLog + chain_log = _zstd._ZSTD_c_chainLog + search_log = _zstd._ZSTD_c_searchLog + min_match = _zstd._ZSTD_c_minMatch + target_length = _zstd._ZSTD_c_targetLength strategy = _zstd._ZSTD_c_strategy - enableLongDistanceMatching = _zstd._ZSTD_c_enableLongDistanceMatching - ldmHashLog = _zstd._ZSTD_c_ldmHashLog - ldmMinMatch = _zstd._ZSTD_c_ldmMinMatch - ldmBucketSizeLog = _zstd._ZSTD_c_ldmBucketSizeLog - ldmHashRateLog = _zstd._ZSTD_c_ldmHashRateLog 
+ enable_long_distance_matching = _zstd._ZSTD_c_enableLongDistanceMatching + ldm_hash_log = _zstd._ZSTD_c_ldmHashLog + ldm_min_match = _zstd._ZSTD_c_ldmMinMatch + ldm_bucket_size_log = _zstd._ZSTD_c_ldmBucketSizeLog + ldm_hash_rate_log = _zstd._ZSTD_c_ldmHashRateLog - contentSizeFlag = _zstd._ZSTD_c_contentSizeFlag - checksumFlag = _zstd._ZSTD_c_checksumFlag - dictIDFlag = _zstd._ZSTD_c_dictIDFlag + content_size_flag = _zstd._ZSTD_c_contentSizeFlag + checksum_flag = _zstd._ZSTD_c_checksumFlag + dict_id_flag = _zstd._ZSTD_c_dictIDFlag - nbWorkers = _zstd._ZSTD_c_nbWorkers - jobSize = _zstd._ZSTD_c_jobSize - overlapLog = _zstd._ZSTD_c_overlapLog + nb_workers = _zstd._ZSTD_c_nbWorkers + job_size = _zstd._ZSTD_c_jobSize + overlap_log = _zstd._ZSTD_c_overlapLog def bounds(self): """Return lower and upper bounds of a compression parameter, both inclusive.""" @@ -213,7 +213,7 @@ def bounds(self): class DParameter(enum.IntEnum): """Decompression parameters""" - windowLogMax = _zstd._ZSTD_d_windowLogMax + window_log_max = _zstd._ZSTD_d_windowLogMax def bounds(self): """Return lower and upper bounds of a decompression parameter, both inclusive.""" @@ -241,4 +241,4 @@ class Strategy(enum.IntEnum): # Set CParameter/DParameter types for validity check _zstd._set_parameter_types(CParameter, DParameter) -zstd_support_multithread = CParameter.nbWorkers.bounds() != (0, 0) +zstd_support_multithread = CParameter.nb_workers.bounds() != (0, 0) diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd/test_core.py index a57fc6a41cd576..eca6dd4a2dd43c 100644 --- a/Lib/test/test_zstd/test_core.py +++ b/Lib/test/test_zstd/test_core.py @@ -73,7 +73,7 @@ def setUpModule(): DAT_130K_D = bytes([random.randint(0, 127) for _ in range(130*_1K)]) global DAT_130K_C - DAT_130K_C = compress(DAT_130K_D, options={CParameter.checksumFlag:1}) + DAT_130K_C = compress(DAT_130K_D, options={CParameter.checksum_flag:1}) global DECOMPRESSED_DAT DECOMPRESSED_DAT = b'abcdefg123456' * 1000 @@ -131,7 +131,7 @@ def test_version(self): self.assertEqual(s, zstd_version) def test_compressionLevel_values(self): - min, max = CParameter.compressionLevel.bounds() + min, max = CParameter.compression_level.bounds() self.assertIs(type(COMPRESSION_LEVEL_DEFAULT), int) self.assertIs(type(min), int) self.assertIs(type(max), int) @@ -145,7 +145,7 @@ def test_roundtrip_default(self): def test_roundtrip_level(self): raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6] - level_min, level_max = CParameter.compressionLevel.bounds() + level_min, level_max = CParameter.compression_level.bounds() for level in range(max(-20, level_min), level_max + 1): dat1 = compress(raw_dat, level) @@ -201,7 +201,7 @@ def test_simple_compress_bad_args(self): ZstdCompressor(options={2**31: 100}) with self.assertRaises(ZstdError): - ZstdCompressor(options={CParameter.windowLog: 100}) + ZstdCompressor(options={CParameter.window_log: 100}) with self.assertRaises(ZstdError): ZstdCompressor(options={3333: 100}) @@ -227,65 +227,65 @@ def test_simple_compress_bad_args(self): zc.flush(zc.FLUSH_FRAME) def test_compress_parameters(self): - d = {CParameter.compressionLevel : 10, - - CParameter.windowLog : 12, - CParameter.hashLog : 10, - CParameter.chainLog : 12, - CParameter.searchLog : 12, - CParameter.minMatch : 4, - CParameter.targetLength : 12, + d = {CParameter.compression_level : 10, + + CParameter.window_log : 12, + CParameter.hash_log : 10, + CParameter.chain_log : 12, + CParameter.search_log : 12, + CParameter.min_match : 4, + CParameter.target_length : 12, 
CParameter.strategy : Strategy.lazy, - CParameter.enableLongDistanceMatching : 1, - CParameter.ldmHashLog : 12, - CParameter.ldmMinMatch : 11, - CParameter.ldmBucketSizeLog : 5, - CParameter.ldmHashRateLog : 12, + CParameter.enable_long_distance_matching : 1, + CParameter.ldm_hash_log : 12, + CParameter.ldm_min_match : 11, + CParameter.ldm_bucket_size_log : 5, + CParameter.ldm_hash_rate_log : 12, - CParameter.contentSizeFlag : 1, - CParameter.checksumFlag : 1, - CParameter.dictIDFlag : 0, + CParameter.content_size_flag : 1, + CParameter.checksum_flag : 1, + CParameter.dict_id_flag : 0, - CParameter.nbWorkers : 2 if zstd_support_multithread else 0, - CParameter.jobSize : 5*_1M if zstd_support_multithread else 0, - CParameter.overlapLog : 9 if zstd_support_multithread else 0, + CParameter.nb_workers : 2 if zstd_support_multithread else 0, + CParameter.job_size : 5*_1M if zstd_support_multithread else 0, + CParameter.overlap_log : 9 if zstd_support_multithread else 0, } ZstdCompressor(options=d) # larger than signed int, ValueError d1 = d.copy() - d1[CParameter.ldmBucketSizeLog] = 2**31 + d1[CParameter.ldm_bucket_size_log] = 2**31 self.assertRaises(ValueError, ZstdCompressor, d1) # clamp compressionLevel - level_min, level_max = CParameter.compressionLevel.bounds() + level_min, level_max = CParameter.compression_level.bounds() compress(b'', level_max+1) compress(b'', level_min-1) - compress(b'', {CParameter.compressionLevel:level_max+1}) - compress(b'', {CParameter.compressionLevel:level_min-1}) + compress(b'', {CParameter.compression_level:level_max+1}) + compress(b'', {CParameter.compression_level:level_min-1}) # zstd lib doesn't support MT compression if not zstd_support_multithread: with self.assertRaises(ZstdError): - ZstdCompressor({CParameter.nbWorkers:4}) + ZstdCompressor({CParameter.nb_workers:4}) with self.assertRaises(ZstdError): - ZstdCompressor({CParameter.jobSize:4}) + ZstdCompressor({CParameter.job_size:4}) with self.assertRaises(ZstdError): - ZstdCompressor({CParameter.overlapLog:4}) + ZstdCompressor({CParameter.overlap_log:4}) # out of bounds error msg - option = {CParameter.windowLog:100} + option = {CParameter.window_log:100} with self.assertRaisesRegex(ZstdError, - (r'Error when setting zstd compression parameter "windowLog", ' + (r'Error when setting zstd compression parameter "window_log", ' r'it should \d+ <= value <= \d+, provided value is 100\. 
' r'\(zstd v\d\.\d\.\d, (?:32|64)-bit build\)')): compress(b'', option) def test_unknown_compression_parameter(self): KEY = 100001234 - option = {CParameter.compressionLevel: 10, + option = {CParameter.compression_level: 10, KEY: 200000000} pattern = r'Zstd compression parameter.*?"unknown parameter \(key %d\)"' \ % KEY @@ -298,8 +298,8 @@ def test_zstd_multithread_compress(self): size = 40*_1M b = THIS_FILE_BYTES * (size // len(THIS_FILE_BYTES)) - options = {CParameter.compressionLevel : 4, - CParameter.nbWorkers : 2} + options = {CParameter.compression_level : 4, + CParameter.nb_workers : 2} # compress() dat1 = compress(b, options=options) @@ -388,7 +388,7 @@ def test_simple_decompress_bad_args(self): ZstdDecompressor(options={2**31 : 100}) with self.assertRaises(ZstdError): - ZstdDecompressor(options={DParameter.windowLogMax:100}) + ZstdDecompressor(options={DParameter.window_log_max:100}) with self.assertRaises(ZstdError): ZstdDecompressor(options={3333 : 100}) @@ -400,25 +400,25 @@ def test_simple_decompress_bad_args(self): lzd.decompress(empty) def test_decompress_parameters(self): - d = {DParameter.windowLogMax : 15} + d = {DParameter.window_log_max : 15} ZstdDecompressor(options=d) # larger than signed int, ValueError d1 = d.copy() - d1[DParameter.windowLogMax] = 2**31 + d1[DParameter.window_log_max] = 2**31 self.assertRaises(ValueError, ZstdDecompressor, None, d1) # out of bounds error msg - options = {DParameter.windowLogMax:100} + options = {DParameter.window_log_max:100} with self.assertRaisesRegex(ZstdError, - (r'Error when setting zstd decompression parameter "windowLogMax", ' + (r'Error when setting zstd decompression parameter "window_log_max", ' r'it should \d+ <= value <= \d+, provided value is 100\. ' r'\(zstd v\d\.\d\.\d, (?:32|64)-bit build\)')): decompress(b'', options=options) def test_unknown_decompression_parameter(self): KEY = 100001234 - options = {DParameter.windowLogMax: DParameter.windowLogMax.bounds()[1], + options = {DParameter.window_log_max: DParameter.window_log_max.bounds()[1], KEY: 200000000} pattern = r'Zstd decompression parameter.*?"unknown parameter \(key %d\)"' \ % KEY @@ -517,7 +517,7 @@ def test_decompressor_arg(self): ZstdDecompressor() ZstdDecompressor(zd, {}) - ZstdDecompressor(zstd_dict=zd, options={DParameter.windowLogMax:25}) + ZstdDecompressor(zstd_dict=zd, options={DParameter.window_log_max:25}) def test_decompressor_1(self): # empty @@ -662,7 +662,7 @@ class DecompressorFlagsTestCase(unittest.TestCase): @classmethod def setUpClass(cls): - options = {CParameter.checksumFlag:1} + options = {CParameter.checksum_flag:1} c = ZstdCompressor(options) cls.DECOMPRESSED_42 = b'a'*42 @@ -1294,9 +1294,9 @@ def test_as_digested_dict(self): zd.as_undigested_dict = b'1234' def test_advanced_compression_parameters(self): - options = {CParameter.compressionLevel: 6, - CParameter.windowLog: 20, - CParameter.enableLongDistanceMatching: 1} + options = {CParameter.compression_level: 6, + CParameter.window_log: 20, + CParameter.enable_long_distance_matching: 1} # automatically select dat = compress(SAMPLES[0], options=options, zstd_dict=TRAINED_DICT) @@ -1327,14 +1327,14 @@ def test_init(self): with ZstdFile(io.BytesIO(), "w", level=12) as f: pass - with ZstdFile(io.BytesIO(), "w", options={CParameter.checksumFlag:1}) as f: + with ZstdFile(io.BytesIO(), "w", options={CParameter.checksum_flag:1}) as f: pass with ZstdFile(io.BytesIO(), "w", options={}) as f: pass with ZstdFile(io.BytesIO(), "w", level=20, zstd_dict=TRAINED_DICT) as f: pass - with 
ZstdFile(io.BytesIO(), "r", options={DParameter.windowLogMax:25}) as f: + with ZstdFile(io.BytesIO(), "r", options={DParameter.window_log_max:25}) as f: pass with ZstdFile(io.BytesIO(), "r", options={}, zstd_dict=TRAINED_DICT) as f: pass @@ -1421,9 +1421,9 @@ def test_init_bad_mode(self): ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rw") with self.assertRaisesRegex(TypeError, r"NOT be CParameter"): - ZstdFile(io.BytesIO(), 'rb', options={CParameter.compressionLevel:5}) + ZstdFile(io.BytesIO(), 'rb', options={CParameter.compression_level:5}) with self.assertRaisesRegex(TypeError, r"NOT be DParameter"): - ZstdFile(io.BytesIO(), 'wb', options={DParameter.windowLogMax:21}) + ZstdFile(io.BytesIO(), 'wb', options={DParameter.window_log_max:21}) with self.assertRaises(TypeError): ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", level=12) @@ -1435,14 +1435,14 @@ def test_init_bad_check(self): with self.assertRaises(ZstdError): ZstdFile(io.BytesIO(), "w", options={999:9999}) with self.assertRaises(ZstdError): - ZstdFile(io.BytesIO(), "w", options={CParameter.windowLog:99}) + ZstdFile(io.BytesIO(), "w", options={CParameter.window_log:99}) with self.assertRaises(TypeError): ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", options=33) with self.assertRaises(ValueError): ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), - options={DParameter.windowLogMax:2**31}) + options={DParameter.window_log_max:2**31}) with self.assertRaises(ZstdError): ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), @@ -1605,7 +1605,7 @@ def test_read_0(self): self.assertEqual(f.read(0), b"") self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), - options={DParameter.windowLogMax:20}) as f: + options={DParameter.window_log_max:20}) as f: self.assertEqual(f.read(0), b"") # empty file @@ -1820,16 +1820,16 @@ def test_write(self): self.assertEqual(dst.getvalue(), expected) with io.BytesIO() as dst: - with ZstdFile(dst, "w", options={CParameter.checksumFlag:1}) as f: + with ZstdFile(dst, "w", options={CParameter.checksum_flag:1}) as f: f.write(raw_data) - comp = ZstdCompressor({CParameter.checksumFlag:1}) + comp = ZstdCompressor({CParameter.checksum_flag:1}) expected = comp.compress(raw_data) + comp.flush() self.assertEqual(dst.getvalue(), expected) with io.BytesIO() as dst: - options = {CParameter.compressionLevel:-5, - CParameter.checksumFlag:1} + options = {CParameter.compression_level:-5, + CParameter.checksum_flag:1} with ZstdFile(dst, "w", options=options) as f: f.write(raw_data) @@ -2288,11 +2288,11 @@ def test_bad_params(self): os.remove(TESTFN) def test_option(self): - options = {DParameter.windowLogMax:25} + options = {DParameter.window_log_max:25} with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb", options=options) as f: self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) - options = {CParameter.compressionLevel:12} + options = {CParameter.compression_level:12} with io.BytesIO() as bio: with open(bio, "wb", options=options) as f: f.write(DECOMPRESSED_100_PLUS_32KB) diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c index 994f151f1f449b..1c122466f671ae 100644 --- a/Modules/_zstd/_zstdmodule.c +++ b/Modules/_zstd/_zstdmodule.c @@ -74,33 +74,33 @@ typedef struct { static const ParameterInfo cp_list[] = { - {ZSTD_c_compressionLevel, "compressionLevel"}, - {ZSTD_c_windowLog, "windowLog"}, - {ZSTD_c_hashLog, "hashLog"}, - {ZSTD_c_chainLog, "chainLog"}, - {ZSTD_c_searchLog, "searchLog"}, - {ZSTD_c_minMatch, "minMatch"}, - {ZSTD_c_targetLength, 
"targetLength"}, + {ZSTD_c_compressionLevel, "compression_level"}, + {ZSTD_c_windowLog, "window_log"}, + {ZSTD_c_hashLog, "hash_log"}, + {ZSTD_c_chainLog, "chain_log"}, + {ZSTD_c_searchLog, "search_log"}, + {ZSTD_c_minMatch, "min_match"}, + {ZSTD_c_targetLength, "target_length"}, {ZSTD_c_strategy, "strategy"}, - {ZSTD_c_enableLongDistanceMatching, "enableLongDistanceMatching"}, - {ZSTD_c_ldmHashLog, "ldmHashLog"}, - {ZSTD_c_ldmMinMatch, "ldmMinMatch"}, - {ZSTD_c_ldmBucketSizeLog, "ldmBucketSizeLog"}, - {ZSTD_c_ldmHashRateLog, "ldmHashRateLog"}, + {ZSTD_c_enableLongDistanceMatching, "enable_long_distance_matching"}, + {ZSTD_c_ldmHashLog, "ldm_hash_log"}, + {ZSTD_c_ldmMinMatch, "ldm_min_match"}, + {ZSTD_c_ldmBucketSizeLog, "ldm_bucket_size_log"}, + {ZSTD_c_ldmHashRateLog, "ldm_hash_rate_log"}, - {ZSTD_c_contentSizeFlag, "contentSizeFlag"}, - {ZSTD_c_checksumFlag, "checksumFlag"}, - {ZSTD_c_dictIDFlag, "dictIDFlag"}, + {ZSTD_c_contentSizeFlag, "content_size_flag"}, + {ZSTD_c_checksumFlag, "checksum_flag"}, + {ZSTD_c_dictIDFlag, "dict_id_flag"}, - {ZSTD_c_nbWorkers, "nbWorkers"}, - {ZSTD_c_jobSize, "jobSize"}, - {ZSTD_c_overlapLog, "overlapLog"} + {ZSTD_c_nbWorkers, "nb_workers"}, + {ZSTD_c_jobSize, "job_size"}, + {ZSTD_c_overlapLog, "overlap_log"} }; static const ParameterInfo dp_list[] = { - {ZSTD_d_windowLogMax, "windowLogMax"} + {ZSTD_d_windowLogMax, "window_log_max"} }; void From ad05da8e931e4afcf3d46684f13e0200b4d823d9 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 11:48:30 -0700 Subject: [PATCH 19/55] Replace compressionLevel_values re-export with COMPRESSION_LEVEL_DEFAULT --- Lib/compression/zstd/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index a1046fd0d77b29..1697b7664a9381 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -4,7 +4,7 @@ __all__ = ( # compression.zstd - "compressionLevel_values", + "COMPRESSION_LEVEL_DEFAULT", "compress", "CParameter", "decompress", From e82e23de5e9acaaeb20673e102313fceb5acb798 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 12:22:27 -0700 Subject: [PATCH 20/55] Move zstd_support_multithread to tests and rename --- Lib/compression/zstd/__init__.py | 2 -- Lib/test/test_zstd/test_core.py | 14 ++++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 1697b7664a9381..514399953fb288 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -240,5 +240,3 @@ class Strategy(enum.IntEnum): # Set CParameter/DParameter types for validity check _zstd._set_parameter_types(CParameter, DParameter) - -zstd_support_multithread = CParameter.nb_workers.bounds() != (0, 0) diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd/test_core.py index eca6dd4a2dd43c..a1439af3903b7d 100644 --- a/Lib/test/test_zstd/test_core.py +++ b/Lib/test/test_zstd/test_core.py @@ -37,7 +37,6 @@ DParameter, Strategy, ZstdFile, - zstd_support_multithread, ) _1K = 1024 @@ -65,8 +64,11 @@ TRAINED_DICT = None +SUPPORT_MULTITHREADING = False def setUpModule(): + global SUPPORT_MULTITHREADING + SUPPORT_MULTITHREADING = CParameter.nb_workers.bounds() != (0, 0) # uncompressed size 130KB, more than a zstd block. # with a frame epilogue, 4 bytes checksum. 
global DAT_130K_D @@ -247,9 +249,9 @@ def test_compress_parameters(self): CParameter.checksum_flag : 1, CParameter.dict_id_flag : 0, - CParameter.nb_workers : 2 if zstd_support_multithread else 0, - CParameter.job_size : 5*_1M if zstd_support_multithread else 0, - CParameter.overlap_log : 9 if zstd_support_multithread else 0, + CParameter.nb_workers : 2 if SUPPORT_MULTITHREADING else 0, + CParameter.job_size : 5*_1M if SUPPORT_MULTITHREADING else 0, + CParameter.overlap_log : 9 if SUPPORT_MULTITHREADING else 0, } ZstdCompressor(options=d) @@ -267,7 +269,7 @@ def test_compress_parameters(self): compress(b'', {CParameter.compression_level:level_min-1}) # zstd lib doesn't support MT compression - if not zstd_support_multithread: + if not SUPPORT_MULTITHREADING: with self.assertRaises(ZstdError): ZstdCompressor({CParameter.nb_workers:4}) with self.assertRaises(ZstdError): @@ -292,7 +294,7 @@ def test_unknown_compression_parameter(self): with self.assertRaisesRegex(ZstdError, pattern): ZstdCompressor(option) - @unittest.skipIf(True,#not zstd_support_multithread, + @unittest.skipIf(not SUPPORT_MULTITHREADING, "zstd build doesn't support multi-threaded compression") def test_zstd_multithread_compress(self): size = 40*_1M From 7801b6b528402b3fd0c30cb7e114b54a1f059948 Mon Sep 17 00:00:00 2001 From: Emma Smith Date: Sun, 4 May 2025 12:26:34 -0700 Subject: [PATCH 21/55] Update module docstring for compression.zstd Co-authored-by: Gregory P. Smith --- Lib/compression/zstd/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 514399953fb288..ac8707e3a48f2a 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -1,6 +1,4 @@ -"""Python bindings to Zstandard (zstd) compression library, the API style is -similar to Python's bz2/lzma/zlib modules. -""" +"""Python bindings to the Zstandard (zstd) compression library (RFC-8878).""" __all__ = ( # compression.zstd From df5d827e41c0467a2649a8f11329ed2d660e960e Mon Sep 17 00:00:00 2001 From: Emma Smith Date: Sun, 4 May 2025 12:28:39 -0700 Subject: [PATCH 22/55] Clarify Strategy stability in docstring Co-authored-by: Adam Turner <9087854+AA-Turner@users.noreply.github.com> --- Lib/compression/zstd/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index ac8707e3a48f2a..4cb394d5f88fae 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -221,8 +221,9 @@ def bounds(self): class Strategy(enum.IntEnum): """Compression strategies, listed from fastest to strongest. - Note : new strategies _might_ be added in the future, only the order - (from fast to strong) is guaranteed. + Note that new strategies might be added in the future. + Only the order (from fast to strong) is guaranteed, + the numeric value might change. """ fast = _zstd._ZSTD_fast From 4ff48da0996830d337fbe0babff77f18c7d20548 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 12:44:29 -0700 Subject: [PATCH 23/55] Fix formatting in tarfile Co-authored-by: Tomas R. 
--- Lib/tarfile.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Lib/tarfile.py b/Lib/tarfile.py index 7e86614be50896..ed7eee193e7981 100644 --- a/Lib/tarfile.py +++ b/Lib/tarfile.py @@ -1835,7 +1835,7 @@ def open(cls, name=None, mode="r", fileobj=None, bufsize=RECORDSIZE, **kwargs): 'w:gz' open for writing with gzip compression 'w:bz2' open for writing with bzip2 compression 'w:xz' open for writing with lzma compression - 'w:zst' open for writing with zstd compression + 'w:zst' open for writing with zstd compression 'x' or 'x:' create a tarfile exclusively without compression, raise an exception if the file is already created @@ -1853,12 +1853,12 @@ def open(cls, name=None, mode="r", fileobj=None, bufsize=RECORDSIZE, **kwargs): 'r|gz' open a gzip compressed stream of tar blocks 'r|bz2' open a bzip2 compressed stream of tar blocks 'r|xz' open an lzma compressed stream of tar blocks - 'r|zst' open a zstd compressed stream of tar blocks + 'r|zst' open a zstd compressed stream of tar blocks 'w|' open an uncompressed stream for writing 'w|gz' open a gzip compressed stream for writing 'w|bz2' open a bzip2 compressed stream for writing 'w|xz' open an lzma compressed stream for writing - 'w|zst' open a zstd compressed stream for writing + 'w|zst' open a zstd compressed stream for writing """ if not name and not fileobj: @@ -2026,7 +2026,7 @@ def xzopen(cls, name, mode="r", fileobj=None, preset=None, **kwargs): @classmethod def zstopen(cls, name, mode="r", fileobj=None, level=None, options=None, - zstd_dict=None, **kwargs): + zstd_dict=None, **kwargs): """Open zstd compressed tar archive name for reading or writing. Appending is not allowed. """ @@ -2064,8 +2064,8 @@ def zstopen(cls, name, mode="r", fileobj=None, level=None, options=None, "tar": "taropen", # uncompressed tar "gz": "gzopen", # gzip compressed tar "bz2": "bz2open", # bzip2 compressed tar - "xz": "xzopen", # lzma compressed tar - "zst": "zstopen" # zstd compressed tar + "xz": "xzopen", # lzma compressed tar + "zst": "zstopen" # zstd compressed tar } #-------------------------------------------------------------------------- From c68a896c39346fba8eb961312dc40126ac410bdc Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 12:46:43 -0700 Subject: [PATCH 24/55] Remove zstd_support_multithread from __all__ --- Lib/compression/zstd/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 4cb394d5f88fae..0986f9329f2c02 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -11,7 +11,6 @@ "get_frame_info", "Strategy", "train_dict", - "zstd_support_multithread", # compression.zstd._zstdfile "open", From 326400db4a0e641c589ec4490cf5867b2084d9d7 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 12:52:56 -0700 Subject: [PATCH 25/55] Add test_name from upstream Co-authored-by: Rogdham --- Lib/test/test_zstd/test_core.py | 34 +++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd/test_core.py index a1439af3903b7d..9eea8058930b97 100644 --- a/Lib/test/test_zstd/test_core.py +++ b/Lib/test/test_zstd/test_core.py @@ -1542,6 +1542,40 @@ def read(self, size=-1): with self.assertRaisesRegex(AttributeError, r'fileno'): f.fileno() + def test_name(self): + # 1 + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + with self.assertRaises(AttributeError): + f.name + finally: + f.close() + with 
self.assertRaises(ValueError): + f.name + + # 2 + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + filename = pathlib.Path(tmp_f.name) + + f = ZstdFile(filename) + try: + self.assertEqual(f.name, f._fp.name) + self.assertIsInstance(f.name, str) + finally: + f.close() + with self.assertRaises(ValueError): + f.name + + os.remove(filename) + + # 3, no .filename property + class C: + def read(self, size=-1): + return b'123' + with ZstdFile(C(), 'rb') as f: + with self.assertRaisesRegex(AttributeError, r'name'): + f.name + def test_seekable(self): f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) try: From 2c0c9a1b728e83435282e88ccee6c365e23e9387 Mon Sep 17 00:00:00 2001 From: Emma Smith Date: Sun, 4 May 2025 12:59:15 -0700 Subject: [PATCH 26/55] Don't close tarfile if there is a BaseException Co-authored-by: Tomas R. --- Lib/tarfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/tarfile.py b/Lib/tarfile.py index ed7eee193e7981..c0f5a609b9f42f 100644 --- a/Lib/tarfile.py +++ b/Lib/tarfile.py @@ -2053,7 +2053,7 @@ def zstopen(cls, name, mode="r", fileobj=None, level=None, options=None, if mode == 'r': raise ReadError("not a zstd file") from e raise - except: + except Exception: fileobj.close() raise t._extfileobj = False From 49f382127c814a2252be4914c740e1fc53daf086 Mon Sep 17 00:00:00 2001 From: Emma Smith Date: Sun, 4 May 2025 13:01:57 -0700 Subject: [PATCH 27/55] Use options kwarg in tests Co-authored-by: Rogdham <3994389+Rogdham@users.noreply.github.com> --- Lib/test/test_zstd/test_core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd/test_core.py index 9eea8058930b97..5d3b7ce934228b 100644 --- a/Lib/test/test_zstd/test_core.py +++ b/Lib/test/test_zstd/test_core.py @@ -258,7 +258,7 @@ def test_compress_parameters(self): # larger than signed int, ValueError d1 = d.copy() d1[CParameter.ldm_bucket_size_log] = 2**31 - self.assertRaises(ValueError, ZstdCompressor, d1) + self.assertRaises(ValueError, ZstdCompressor, options=d1) # clamp compressionLevel level_min, level_max = CParameter.compression_level.bounds() @@ -283,7 +283,7 @@ def test_compress_parameters(self): (r'Error when setting zstd compression parameter "window_log", ' r'it should \d+ <= value <= \d+, provided value is 100\. ' r'\(zstd v\d\.\d\.\d, (?:32|64)-bit build\)')): - compress(b'', option) + compress(b'', options=option) def test_unknown_compression_parameter(self): KEY = 100001234 @@ -292,7 +292,7 @@ def test_unknown_compression_parameter(self): pattern = r'Zstd compression parameter.*?"unknown parameter \(key %d\)"' \ % KEY with self.assertRaisesRegex(ZstdError, pattern): - ZstdCompressor(option) + ZstdCompressor(options=option) @unittest.skipIf(not SUPPORT_MULTITHREADING, "zstd build doesn't support multi-threaded compression") @@ -665,7 +665,7 @@ class DecompressorFlagsTestCase(unittest.TestCase): @classmethod def setUpClass(cls): options = {CParameter.checksum_flag:1} - c = ZstdCompressor(options) + c = ZstdCompressor(options=options) cls.DECOMPRESSED_42 = b'a'*42 cls.FRAME_42 = c.compress(cls.DECOMPRESSED_42, c.FLUSH_FRAME) From 8ba6bdad44ba764bc3360731bc7ac5d4a550019b Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 13:31:44 -0700 Subject: [PATCH 28/55] Use options kwarg in tests in more places These were not shown because the flag names changed. 
Co-authored-by: Rogdham <3994389+Rogdham@users.noreply.github.com> --- Lib/test/test_zstd/test_core.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd/test_core.py index 5d3b7ce934228b..15e34edf40c34a 100644 --- a/Lib/test/test_zstd/test_core.py +++ b/Lib/test/test_zstd/test_core.py @@ -265,17 +265,17 @@ def test_compress_parameters(self): compress(b'', level_max+1) compress(b'', level_min-1) - compress(b'', {CParameter.compression_level:level_max+1}) - compress(b'', {CParameter.compression_level:level_min-1}) + compress(b'', options={CParameter.compression_level:level_max+1}) + compress(b'', options={CParameter.compression_level:level_min-1}) # zstd lib doesn't support MT compression if not SUPPORT_MULTITHREADING: with self.assertRaises(ZstdError): - ZstdCompressor({CParameter.nb_workers:4}) + ZstdCompressor(options={CParameter.nb_workers:4}) with self.assertRaises(ZstdError): - ZstdCompressor({CParameter.job_size:4}) + ZstdCompressor(options={CParameter.job_size:4}) with self.assertRaises(ZstdError): - ZstdCompressor({CParameter.overlap_log:4}) + ZstdCompressor(options={CParameter.overlap_log:4}) # out of bounds error msg option = {CParameter.window_log:100} @@ -1859,7 +1859,7 @@ def test_write(self): with ZstdFile(dst, "w", options={CParameter.checksum_flag:1}) as f: f.write(raw_data) - comp = ZstdCompressor({CParameter.checksum_flag:1}) + comp = ZstdCompressor(options={CParameter.checksum_flag:1}) expected = comp.compress(raw_data) + comp.flush() self.assertEqual(dst.getvalue(), expected) From 129d5e6f0099b21e7c7f31e5047432d4d9f0018a Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 13:36:00 -0700 Subject: [PATCH 29/55] Adopt suggestions by Tomas R. for _zstdfile Co-authored-by: Tomas R. --- Lib/compression/zstd/_zstdfile.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Lib/compression/zstd/_zstdfile.py b/Lib/compression/zstd/_zstdfile.py index 1ca60fe5677454..0fd31254cee19f 100644 --- a/Lib/compression/zstd/_zstdfile.py +++ b/Lib/compression/zstd/_zstdfile.py @@ -72,21 +72,21 @@ def __init__( "options." 
) ) - if level: + if level is not None: raise TypeError("level argument should only be passed when writing.") mode_code = _MODE_READ elif mode in ("w", "wb", "a", "ab", "x", "xb"): if not isinstance(level, (type(None), int)): - raise TypeError(("level argument should be an int object.")) + raise TypeError("level argument should be an int object.") if not isinstance(options, (type(None), dict)): - raise TypeError(("options argument should be an dict object.")) + raise TypeError("options argument should be an dict object.") mode_code = _MODE_WRITE self._compressor = ZstdCompressor( level=level, options=options, zstd_dict=zstd_dict ) self._pos = 0 else: - raise ValueError("Invalid mode: {!r}".format(mode)) + raise ValueError(f"Invalid mode: {mode!r}") # File object if isinstance(filename, (str, bytes, PathLike)): @@ -97,7 +97,7 @@ def __init__( elif hasattr(filename, "read") or hasattr(filename, "write"): self._fp = filename else: - raise TypeError(("filename must be a str, bytes, file or PathLike object")) + raise TypeError("filename must be a str, bytes, file or PathLike object") self._mode = mode_code if self._mode == _MODE_READ: @@ -358,7 +358,7 @@ def open( if "t" in mode: if "b" in mode: - raise ValueError("Invalid mode: %r" % (mode,)) + raise ValueError(f"Invalid mode: {mode!r}") else: if encoding is not None: raise ValueError("Argument 'encoding' not supported in binary mode") From 7d54d357ab2d07e133f7f5cbe62c41df1f0a4641 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 13:38:36 -0700 Subject: [PATCH 30/55] Formatting fixes in zstd tests Co-authored-by: Tomas R. --- Lib/test/test_zstd/test_core.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd/test_core.py index 15e34edf40c34a..b8c6e770090bc9 100644 --- a/Lib/test/test_zstd/test_core.py +++ b/Lib/test/test_zstd/test_core.py @@ -301,7 +301,7 @@ def test_zstd_multithread_compress(self): b = THIS_FILE_BYTES * (size // len(THIS_FILE_BYTES)) options = {CParameter.compression_level : 4, - CParameter.nb_workers : 2} + CParameter.nb_workers : 2} # compress() dat1 = compress(b, options=options) @@ -317,8 +317,7 @@ def test_zstd_multithread_compress(self): self.assertEqual(dat4, b * 3) # ZstdFile - with ZstdFile(io.BytesIO(), 'w', - options=options) as f: + with ZstdFile(io.BytesIO(), 'w', options=options) as f: f.write(b) def test_compress_flushblock(self): @@ -331,7 +330,7 @@ def test_compress_flushblock(self): dat1 += c.compress(THIS_FILE_BYTES[point:], c.FLUSH_BLOCK) self.assertEqual(c.last_mode, c.FLUSH_BLOCK) dat2 = c.flush() - pattern = r"Compressed data ended before the end-of-stream marker" + pattern = "Compressed data ended before the end-of-stream marker" with self.assertRaisesRegex(ZstdError, pattern): decompress(dat1) From 03795ec8a30288d5360149842546cf1758500a7c Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 13:53:46 -0700 Subject: [PATCH 31/55] Improve docstrings for (de)compress --- Lib/compression/zstd/__init__.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 0986f9329f2c02..82b560722de8ba 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -147,21 +147,27 @@ def finalize_dict(zstd_dict, samples, dict_size, level): def compress(data, level=None, options=None, zstd_dict=None): """Return Zstandard compressed *data* as bytes. 
- Refer to ZstdCompressor's docstring for a description of the - optional arguments *level*, *options*, and *zstd_dict*. - - For incremental compression, use an ZstdCompressor instead. + *level* is an int specifying the compression level to use, defaulting to + COMPRESSION_LEVEL_DEFAULT + *options* is a dict object that contains advanced compression + parameters. See CParameter for more on options. + *zstd_dict* is a ZstdDict object, a pre-trained Zstandard dictionary. See + the function train_dict for how to train a ZstdDict on sample data. + + For incremental compression, use a ZstdCompressor instead. """ comp = ZstdCompressor(level=level, options=options, zstd_dict=zstd_dict) return comp.compress(data, mode=ZstdCompressor.FLUSH_FRAME) def decompress(data, zstd_dict=None, options=None): - """Decompress one or more frames of data. + """Decompress one or more frames of Zstandard compressed *data*. - Refer to ZstdDecompressor's docstring for a description of the - optional arguments *zstd_dict*, *options*. + *zstd_dict* is a ZstdDict object, a pre-trained Zstandard dictionary. See + the function train_dict for how to train a ZstdDict on sample data. + *options* is a dict object that contains advanced compression + parameters. See DParameter for more on options. - For incremental decompression, use an ZstdDecompressor instead. + For incremental decompression, use a ZstdDecompressor instead. """ results = [] while True: From 01fcfcb55cd4d3e9e3ba7190e6241ae2c4ec9c89 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 14:00:50 -0700 Subject: [PATCH 32/55] Fix some line length issues --- Lib/compression/zstd/__init__.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 82b560722de8ba..b2bd00725efc34 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -62,17 +62,17 @@ def __setattr__(self, name, _): def get_frame_info(frame_buffer): """Get zstd frame information from a frame header. - *frame_buffer* is a bytes-like object. It should starts from the beginning of - a frame, and needs to include at least the frame header (6 to 18 bytes). + *frame_buffer* is a bytes-like object. It should starts from the beginning + of a frame, and needs to include at least the frame header (6 to 18 bytes). Return a FrameInfo dataclass, which currently has two attributes 'decompressed_size' is the size in bytes of the data in the frame when decompressed, or None when the decompressed size is unknown. - 'dictionary_id' is a 32-bit unsigned integer value. 0 means dictionary ID was - not recorded in the frame header, the frame may or may not need a dictionary - to be decoded, and the ID of such a dictionary is not specified. + 'dictionary_id' is a 32-bit unsigned integer value. 0 means dictionary ID + was not recorded in the frame header, the frame may or may not need a + dictionary to be decoded, and the ID of such a dictionary is not specified. 
""" return FrameInfo(*_zstd._get_frame_info(frame_buffer)) @@ -138,7 +138,8 @@ def finalize_dict(zstd_dict, samples, dict_size, level): chunks = b''.join(samples) chunk_sizes = tuple(_nbytes(sample) for sample in samples) if not chunks: - raise ValueError("The samples are empty content, can't finalize dictionary.") + raise ValueError("The samples are empty content, can't finalize" + "dictionary.") dict_content = _finalize_dict(zstd_dict.dict_content, chunks, chunk_sizes, dict_size, level) From f04494c6618f1347d11c13054045dde310b20fd0 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 14:06:00 -0700 Subject: [PATCH 33/55] Improve docstring on C/DParameter.bounds() --- Lib/compression/zstd/__init__.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index b2bd00725efc34..087bff94ec4cbb 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -210,7 +210,11 @@ class CParameter(enum.IntEnum): overlap_log = _zstd._ZSTD_c_overlapLog def bounds(self): - """Return lower and upper bounds of a compression parameter, both inclusive.""" + """Returns a tuple of ints (lower, upper), representing the bounds of a + compression parameter. + + Both lower and upper bounds are inclusive. + """ return _zstd._get_param_bounds(is_compress=True, parameter=self.value) @@ -220,7 +224,11 @@ class DParameter(enum.IntEnum): window_log_max = _zstd._ZSTD_d_windowLogMax def bounds(self): - """Return lower and upper bounds of a decompression parameter, both inclusive.""" + """Returns a tuple of ints (lower, upper) representing the bounds of a + decompression parameter. + + Both lower and upper bounds are inclusive. + """ return _zstd._get_param_bounds(is_compress=False, parameter=self.value) From caa40b192ca63dbcfa98ea4ad5ed9ed00b77677d Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 14:13:21 -0700 Subject: [PATCH 34/55] Improve docstrings and formatting --- Lib/compression/zstd/__init__.py | 4 ++++ Lib/compression/zstd/_zstdfile.py | 33 ++++++++++++++++++------------- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 087bff94ec4cbb..511e840837325d 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -89,6 +89,7 @@ def train_dict(samples, dict_size): *samples* is an iterable of samples, where a sample is a bytes-like object representing a file. + *dict_size* is the dictionary's maximum size, in bytes. """ if not isinstance(dict_size, int): @@ -150,8 +151,10 @@ def compress(data, level=None, options=None, zstd_dict=None): *level* is an int specifying the compression level to use, defaulting to COMPRESSION_LEVEL_DEFAULT + *options* is a dict object that contains advanced compression parameters. See CParameter for more on options. + *zstd_dict* is a ZstdDict object, a pre-trained Zstandard dictionary. See the function train_dict for how to train a ZstdDict on sample data. @@ -165,6 +168,7 @@ def decompress(data, zstd_dict=None, options=None): *zstd_dict* is a ZstdDict object, a pre-trained Zstandard dictionary. See the function train_dict for how to train a ZstdDict on sample data. + *options* is a dict object that contains advanced compression parameters. See DParameter for more on options. 
diff --git a/Lib/compression/zstd/_zstdfile.py b/Lib/compression/zstd/_zstdfile.py index 0fd31254cee19f..59ec80bf105c2f 100644 --- a/Lib/compression/zstd/_zstdfile.py +++ b/Lib/compression/zstd/_zstdfile.py @@ -3,7 +3,8 @@ from os import PathLike -from _zstd import ZstdCompressor, ZstdDecompressor, _ZSTD_DStreamSizes, ZstdError +from _zstd import (ZstdCompressor, ZstdDecompressor, _ZSTD_DStreamSizes, + ZstdError) from compression._common import _streams __all__ = ("ZstdFile", "open") @@ -50,13 +51,15 @@ def __init__( creating exclusively, or "a" for appending. These can equivalently be given as "rb", "wb", "xb" and "ab" respectively. - Parameters - level: The compression level to use, defaults to ZSTD_CLEVEL_DEFAULT. Note, - in read mode (decompression), compression level is not supported. - options: A dict object, containing advanced compression - parameters. - zstd_dict: A ZstdDict object, pre-trained dictionary for compression / - decompression. + + *level* is an int specifying the compression level to use, defaulting to + COMPRESSION_LEVEL_DEFAULT + + *options* is a dict object that contains advanced compression + parameters. See CParameter or DParameter for more on options. + + *zstd_dict* is a ZstdDict object, a pre-trained Zstandard dictionary. + See the function train_dict for how to train a ZstdDict on sample data. """ self._fp = None self._closefp = False @@ -68,12 +71,13 @@ def __init__( raise TypeError( ( "In read mode (decompression), options argument " - "should be a dict object, that represents decompression " - "options." + "should be a dict object, that represents " + "decompression options." ) ) if level is not None: - raise TypeError("level argument should only be passed when writing.") + raise TypeError("level argument should only be passed when " + "writing.") mode_code = _MODE_READ elif mode in ("w", "wb", "a", "ab", "x", "xb"): if not isinstance(level, (type(None), int)): @@ -97,7 +101,8 @@ def __init__( elif hasattr(filename, "read") or hasattr(filename, "write"): self._fp = filename else: - raise TypeError("filename must be a str, bytes, file or PathLike object") + raise TypeError("filename must be a str, bytes, file or PathLike " + "object") self._mode = mode_code if self._mode == _MODE_READ: @@ -137,7 +142,7 @@ def close(self): self._closefp = False def write(self, data): - """Write a bytes-like object to the file. + """Write a bytes-like object *data* to the file. Returns the number of uncompressed bytes written, which is always the length of data in bytes. Note that due to buffering, @@ -160,7 +165,7 @@ def write(self, data): def flush(self, mode=FLUSH_BLOCK): """Flush remaining data to the underlying stream. - The mode argument can be ZstdFile.FLUSH_BLOCK, ZstdFile.FLUSH_FRAME. + The mode argument can be ZstdFile.FLUSH_BLOCK or ZstdFile.FLUSH_FRAME. Abuse of this method will reduce compression ratio, use it only when necessary. 
From 4584ec5d76761d0489f2797d5efca0e4150c63ad Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 14:13:36 -0700 Subject: [PATCH 35/55] Add missing f string prefix --- Lib/compression/zstd/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 511e840837325d..f95fff5fdc787a 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -94,7 +94,7 @@ def train_dict(samples, dict_size): """ if not isinstance(dict_size, int): ds_cls = type(dict_size).__qualname__ - raise TypeError('dict_size must be an int object, not {ds_cls!r}.') + raise TypeError(f'dict_size must be an int object, not {ds_cls!r}.') samples = tuple(samples) chunks = b''.join(samples) From 3cafdc6ebd2c61e9a6642060cd84f619ead75972 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 14:17:24 -0700 Subject: [PATCH 36/55] Fix weird indent in _zstdfile.py --- Lib/compression/zstd/_zstdfile.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/Lib/compression/zstd/_zstdfile.py b/Lib/compression/zstd/_zstdfile.py index 59ec80bf105c2f..db788e14f1f8b3 100644 --- a/Lib/compression/zstd/_zstdfile.py +++ b/Lib/compression/zstd/_zstdfile.py @@ -51,15 +51,14 @@ def __init__( creating exclusively, or "a" for appending. These can equivalently be given as "rb", "wb", "xb" and "ab" respectively. - - *level* is an int specifying the compression level to use, defaulting to - COMPRESSION_LEVEL_DEFAULT + level is an int specifying the compression level to use, defaulting to + COMPRESSION_LEVEL_DEFAULT - *options* is a dict object that contains advanced compression - parameters. See CParameter or DParameter for more on options. + options is a dict object that contains advanced compression + parameters. See CParameter or DParameter for more on options. - *zstd_dict* is a ZstdDict object, a pre-trained Zstandard dictionary. - See the function train_dict for how to train a ZstdDict on sample data. + zstd_dict is a ZstdDict object, a pre-trained Zstandard dictionary. + See the function train_dict for how to train a ZstdDict on sample data. 
""" self._fp = None self._closefp = False From 8cb0846742435e5bb0c1b9a7e77f85754ac34f95 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 18:21:53 -0700 Subject: [PATCH 37/55] Use io.open instead of builtins.open Co-authored-by: Adam Turner <9087854+AA-Turner@users.noreply.github.com> --- Lib/compression/zstd/_zstdfile.py | 3 +-- Lib/test/test_zstd/test_core.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/Lib/compression/zstd/_zstdfile.py b/Lib/compression/zstd/_zstdfile.py index db788e14f1f8b3..d1d8d2bd2ce8b8 100644 --- a/Lib/compression/zstd/_zstdfile.py +++ b/Lib/compression/zstd/_zstdfile.py @@ -1,4 +1,3 @@ -import builtins import io from os import PathLike @@ -95,7 +94,7 @@ def __init__( if isinstance(filename, (str, bytes, PathLike)): if "b" not in mode: mode += "b" - self._fp = builtins.open(filename, mode) + self._fp = io.open(filename, mode) self._closefp = True elif hasattr(filename, "read") or hasattr(filename, "write"): self._fp = filename diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd/test_core.py index b8c6e770090bc9..1a6e87c62c5714 100644 --- a/Lib/test/test_zstd/test_core.py +++ b/Lib/test/test_zstd/test_core.py @@ -3,7 +3,6 @@ import io import pathlib import random -import builtins import re import os import unittest @@ -95,7 +94,7 @@ def setUpModule(): b'a' * (32*_1K) global THIS_FILE_BYTES, THIS_FILE_STR - with builtins.open(os.path.abspath(__file__), 'rb') as f: + with io.open(os.path.abspath(__file__), 'rb') as f: THIS_FILE_BYTES = f.read() THIS_FILE_BYTES = re.sub(rb'\r?\n', rb'\n', THIS_FILE_BYTES) THIS_FILE_STR = THIS_FILE_BYTES.decode('utf-8') From c7d5d67505f7eb7fca6111fc35319c6665415d16 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 18:23:08 -0700 Subject: [PATCH 38/55] Remove _READER_CLASS from ZstdFile Co-authored-by: Adam Turner <9087854+AA-Turner@users.noreply.github.com> --- Lib/compression/zstd/_zstdfile.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/Lib/compression/zstd/_zstdfile.py b/Lib/compression/zstd/_zstdfile.py index d1d8d2bd2ce8b8..0e8d6a0f26586a 100644 --- a/Lib/compression/zstd/_zstdfile.py +++ b/Lib/compression/zstd/_zstdfile.py @@ -26,8 +26,6 @@ class ZstdFile(_streams.BaseStream): supports the Buffer Protocol. 
""" - _READER_CLASS = _streams.DecompressReader - FLUSH_BLOCK = ZstdCompressor.FLUSH_BLOCK FLUSH_FRAME = ZstdCompressor.FLUSH_FRAME @@ -104,7 +102,7 @@ def __init__( self._mode = mode_code if self._mode == _MODE_READ: - raw = self._READER_CLASS( + raw = _streams.DecompressReader( self._fp, ZstdDecompressor, trailing_error=ZstdError, From a56a22e1444b375f4c985bcb25a5a9b5fe1958d6 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 19:12:41 -0700 Subject: [PATCH 39/55] Adopt many suggestions from AA-Turner for ZstdFile * rename filename argument to file * improve __init__ mode and argument checking * docstring and error rewording * renamed self._closefp to self._close_fp * removed mode_code from __init__ * removed unneeded self._READER_CLASS Co-authored-by: Adam Turner <9087854+AA-Turner@users.noreply.github.com> --- Lib/compression/zstd/_zstdfile.py | 93 ++++++++++++++----------------- Lib/test/test_zstd/test_core.py | 5 +- 2 files changed, 44 insertions(+), 54 deletions(-) diff --git a/Lib/compression/zstd/_zstdfile.py b/Lib/compression/zstd/_zstdfile.py index 0e8d6a0f26586a..0ec76fba5fec92 100644 --- a/Lib/compression/zstd/_zstdfile.py +++ b/Lib/compression/zstd/_zstdfile.py @@ -31,7 +31,8 @@ class ZstdFile(_streams.BaseStream): def __init__( self, - filename, + file, + /, mode="r", *, level=None, @@ -40,7 +41,7 @@ def __init__( ): """Open a zstd compressed file in binary mode. - filename can be either an actual file name (given as a str, bytes, or + file can be either an actual file name (given as a str, bytes, or PathLike object), in which case the named file is opened, or it can be an existing file object to read from or write to. @@ -58,29 +59,23 @@ def __init__( See the function train_dict for how to train a ZstdDict on sample data. """ self._fp = None - self._closefp = False + self._close_fp = False self._mode = _MODE_CLOSED + if not isinstance(mode, str): + raise ValueError("mode must be a str") # Read or write mode - if mode in ("r", "rb"): - if not isinstance(options, (type(None), dict)): - raise TypeError( - ( - "In read mode (decompression), options argument " - "should be a dict object, that represents " - "decompression options." 
- ) - ) + if options is not None and not isinstance(options, dict): + raise TypeError("options must be a dict or None") + mode = mode.removesuffix("b") # handle rb, wb, xb, ab + if mode == "r": if level is not None: - raise TypeError("level argument should only be passed when " - "writing.") - mode_code = _MODE_READ - elif mode in ("w", "wb", "a", "ab", "x", "xb"): - if not isinstance(level, (type(None), int)): - raise TypeError("level argument should be an int object.") - if not isinstance(options, (type(None), dict)): - raise TypeError("options argument should be an dict object.") - mode_code = _MODE_WRITE + raise TypeError("level is illegal in read mode") + self._mode = _MODE_READ + elif mode in {"w", "a", "x"}: + if level is not None and not isinstance(level, int): + raise TypeError("level must be int or None") + self._mode = _MODE_WRITE self._compressor = ZstdCompressor( level=level, options=options, zstd_dict=zstd_dict ) @@ -89,17 +84,15 @@ def __init__( raise ValueError(f"Invalid mode: {mode!r}") # File object - if isinstance(filename, (str, bytes, PathLike)): - if "b" not in mode: - mode += "b" - self._fp = io.open(filename, mode) - self._closefp = True - elif hasattr(filename, "read") or hasattr(filename, "write"): - self._fp = filename + if isinstance(file, (str, bytes, PathLike)): + self._fp = io.open(file, f'{mode}b') + self._close_fp = True + elif ((mode == 'r' and hasattr(file, "read")) + or (mode != 'r' and hasattr(file, "write"))): + self._fp = file else: - raise TypeError("filename must be a str, bytes, file or PathLike " - "object") - self._mode = mode_code + raise TypeError("file must be a file-like object " + "or a str, bytes, or PathLike object") if self._mode == _MODE_READ: raw = _streams.DecompressReader( @@ -114,15 +107,14 @@ def __init__( def close(self): """Flush and close the file. - May be called more than once without error. Once the file is - closed, any other operation on it will raise a ValueError. + May be called multiple times. Once the file has been closed, + any other operation on it will raise ValueError. """ - # Nop if already closed if self._fp is None: return try: if self._mode == _MODE_READ: - if hasattr(self, "_buffer") and self._buffer: + if getattr(self, '_buffer', None): self._buffer.close() self._buffer = None elif self._mode == _MODE_WRITE: @@ -131,11 +123,11 @@ def close(self): finally: self._mode = _MODE_CLOSED try: - if self._closefp: + if self._close_fp: self._fp.close() finally: self._fp = None - self._closefp = False + self._close_fp = False def write(self, data): """Write a bytes-like object *data* to the file. @@ -161,9 +153,8 @@ def write(self, data): def flush(self, mode=FLUSH_BLOCK): """Flush remaining data to the underlying stream. - The mode argument can be ZstdFile.FLUSH_BLOCK or ZstdFile.FLUSH_FRAME. - Abuse of this method will reduce compression ratio, use it only when - necessary. + The mode argument can be FLUSH_BLOCK or FLUSH_FRAME. Abuse of this + method will reduce compression ratio, use it only when necessary. If the program is interrupted afterwards, all data can be recovered. To ensure saving to disk, also need to use os.fsync(fd). 
@@ -173,10 +164,10 @@ def flush(self, mode=FLUSH_BLOCK): if self._mode == _MODE_READ: return self._check_not_closed() - if mode not in (self.FLUSH_BLOCK, self.FLUSH_FRAME): - raise ValueError("mode argument wrong value, it should be " - "ZstdCompressor.FLUSH_FRAME or " - "ZstdCompressor.FLUSH_BLOCK.") + if mode not in {self.FLUSH_BLOCK, self.FLUSH_FRAME}: + raise ValueError("Invalid mode argument, expected either " + "ZstdFile.FLUSH_FRAME or " + "ZstdFile.FLUSH_BLOCK") if self._compressor.last_mode == mode: return # Flush zstd block/frame, and write. @@ -270,8 +261,7 @@ def peek(self, size=-1): return self._buffer.peek(size) def __next__(self): - ret = self._buffer.readline() - if ret: + if ret := self._buffer.readline(): return ret raise StopIteration @@ -319,7 +309,8 @@ def writable(self): # Copied from lzma module def open( - filename, + file, + /, mode="rb", *, level=None, @@ -331,9 +322,9 @@ def open( ): """Open a zstd compressed file in binary or text mode. - filename can be either an actual file name (given as a str, bytes, or - PathLike object), in which case the named file is opened, or it can be an - existing file object to read from or write to. + file can be either a file name (given as a str, bytes, or PathLike object), + in which case the named file is opened, or it can be an existing file object + to read from or write to. The mode parameter can be "r", "rb" (default), "w", "wb", "x", "xb", "a", "ab" for binary mode, or "rt", "wt", "xt", "at" for text mode. @@ -370,7 +361,7 @@ def open( zstd_mode = mode.replace("t", "") binary_file = ZstdFile( - filename, zstd_mode, level=level, options=options, zstd_dict=zstd_dict + file, zstd_mode, level=level, options=options, zstd_dict=zstd_dict ) if "t" in mode: diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd/test_core.py index 1a6e87c62c5714..1dac944390125d 100644 --- a/Lib/test/test_zstd/test_core.py +++ b/Lib/test/test_zstd/test_core.py @@ -2121,10 +2121,9 @@ class T: def read(self, size): return b'a' * size - with self.assertRaises(AttributeError): # on close + with self.assertRaises(TypeError): # on creation with ZstdFile(T(), 'w') as f: - with self.assertRaises(AttributeError): # on write - f.write(b'1234') + pass # 3 with ZstdFile(io.BytesIO(), 'w') as f: From 7e919c8488c8aed97797fdf454c4d4816db210bc Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Sun, 4 May 2025 19:22:49 -0700 Subject: [PATCH 40/55] Set self._buffer to None --- Lib/compression/zstd/_zstdfile.py | 1 + 1 file changed, 1 insertion(+) diff --git a/Lib/compression/zstd/_zstdfile.py b/Lib/compression/zstd/_zstdfile.py index 0ec76fba5fec92..1414223684bf51 100644 --- a/Lib/compression/zstd/_zstdfile.py +++ b/Lib/compression/zstd/_zstdfile.py @@ -61,6 +61,7 @@ def __init__( self._fp = None self._close_fp = False self._mode = _MODE_CLOSED + self._buffer = None if not isinstance(mode, str): raise ValueError("mode must be a str") From 389faedb79d8b51f485166169e1bd18d388c09fd Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Mon, 5 May 2025 08:11:59 -0700 Subject: [PATCH 41/55] Move _nbytes to _zstdfile.py --- Lib/compression/zstd/__init__.py | 9 +-------- Lib/compression/zstd/_zstdfile.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index f95fff5fdc787a..0b1044c53471a5 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -28,7 +28,7 @@ import enum -from compression.zstd._zstdfile import ZstdFile, open 
+from compression.zstd._zstdfile import ZstdFile, open, _nbytes from _zstd import * import _zstd @@ -77,13 +77,6 @@ def get_frame_info(frame_buffer): return FrameInfo(*_zstd._get_frame_info(frame_buffer)) -def _nbytes(dat): - if isinstance(dat, (bytes, bytearray)): - return len(dat) - with memoryview(dat) as mv: - return mv.nbytes - - def train_dict(samples, dict_size): """Train a zstd dictionary, return a ZstdDict object. diff --git a/Lib/compression/zstd/_zstdfile.py b/Lib/compression/zstd/_zstdfile.py index 1414223684bf51..9c5f8db2d7e296 100644 --- a/Lib/compression/zstd/_zstdfile.py +++ b/Lib/compression/zstd/_zstdfile.py @@ -15,6 +15,13 @@ _MODE_WRITE = 2 +def _nbytes(dat): + if isinstance(dat, (bytes, bytearray)): + return len(dat) + with memoryview(dat) as mv: + return mv.nbytes + + class ZstdFile(_streams.BaseStream): """A file object providing transparent zstd (de)compression. @@ -139,12 +146,8 @@ def write(self, data): or .close() is called. """ self._check_can_write() - if isinstance(data, (bytes, bytearray)): - length = len(data) - else: - # accept any data that supports the buffer protocol - data = memoryview(data) - length = data.nbytes + + length = _nbytes(data) compressed = self._compressor.compress(data) self._fp.write(compressed) From 006ef2ece93e0d17abbd8b37754e9f7e0729c516 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Mon, 5 May 2025 08:14:57 -0700 Subject: [PATCH 42/55] Move test_zstd to file --- Lib/test/{test_zstd/test_core.py => test_zstd.py} | 0 Lib/test/test_zstd/__init__.py | 5 ----- Lib/test/test_zstd/__main__.py | 7 ------- 3 files changed, 12 deletions(-) rename Lib/test/{test_zstd/test_core.py => test_zstd.py} (100%) delete mode 100644 Lib/test/test_zstd/__init__.py delete mode 100644 Lib/test/test_zstd/__main__.py diff --git a/Lib/test/test_zstd/test_core.py b/Lib/test/test_zstd.py similarity index 100% rename from Lib/test/test_zstd/test_core.py rename to Lib/test/test_zstd.py diff --git a/Lib/test/test_zstd/__init__.py b/Lib/test/test_zstd/__init__.py deleted file mode 100644 index 4b16ecc31156a5..00000000000000 --- a/Lib/test/test_zstd/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -import os -from test.support import load_package_tests - -def load_tests(*args): - return load_package_tests(os.path.dirname(__file__), *args) diff --git a/Lib/test/test_zstd/__main__.py b/Lib/test/test_zstd/__main__.py deleted file mode 100644 index e25ac946edffe4..00000000000000 --- a/Lib/test/test_zstd/__main__.py +++ /dev/null @@ -1,7 +0,0 @@ -import unittest - -from . 
import load_tests # noqa: F401 - - -if __name__ == "__main__": - unittest.main() From c846b78d73180b65e1ff08aa051b9bd9969cd8b9 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Mon, 5 May 2025 08:21:48 -0700 Subject: [PATCH 43/55] Rename C/DParameter to (De)CompressionParameter --- Lib/compression/zstd/__init__.py | 16 ++-- Lib/compression/zstd/_zstdfile.py | 3 +- Lib/test/test_zstd.py | 135 +++++++++++++++--------------- Modules/_zstd/_zstdmodule.c | 14 ++-- Modules/_zstd/compressor.c | 4 +- Modules/_zstd/decompressor.c | 4 +- 6 files changed, 90 insertions(+), 86 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 0b1044c53471a5..bd95105c44f5a0 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -4,9 +4,9 @@ # compression.zstd "COMPRESSION_LEVEL_DEFAULT", "compress", - "CParameter", + "CompressionParameter", "decompress", - "DParameter", + "DecompressionParameter", "finalize_dict", "get_frame_info", "Strategy", @@ -146,7 +146,7 @@ def compress(data, level=None, options=None, zstd_dict=None): COMPRESSION_LEVEL_DEFAULT *options* is a dict object that contains advanced compression - parameters. See CParameter for more on options. + parameters. See CompressionParameter for more on options. *zstd_dict* is a ZstdDict object, a pre-trained Zstandard dictionary. See the function train_dict for how to train a ZstdDict on sample data. @@ -163,7 +163,7 @@ def decompress(data, zstd_dict=None, options=None): the function train_dict for how to train a ZstdDict on sample data. *options* is a dict object that contains advanced compression - parameters. See DParameter for more on options. + parameters. See DecompressionParameter for more on options. For incremental decompression, use a ZstdDecompressor instead. """ @@ -180,7 +180,7 @@ def decompress(data, zstd_dict=None, options=None): return b"".join(results) -class CParameter(enum.IntEnum): +class CompressionParameter(enum.IntEnum): """Compression parameters""" compression_level = _zstd._ZSTD_c_compressionLevel @@ -215,7 +215,7 @@ def bounds(self): return _zstd._get_param_bounds(is_compress=True, parameter=self.value) -class DParameter(enum.IntEnum): +class DecompressionParameter(enum.IntEnum): """Decompression parameters""" window_log_max = _zstd._ZSTD_d_windowLogMax @@ -248,5 +248,5 @@ class Strategy(enum.IntEnum): btultra2 = _zstd._ZSTD_btultra2 -# Set CParameter/DParameter types for validity check -_zstd._set_parameter_types(CParameter, DParameter) +# Set CompressionParameter/DecompressionParameter types for validity check +_zstd._set_parameter_types(CompressionParameter, DecompressionParameter) diff --git a/Lib/compression/zstd/_zstdfile.py b/Lib/compression/zstd/_zstdfile.py index 9c5f8db2d7e296..366a3e48930f33 100644 --- a/Lib/compression/zstd/_zstdfile.py +++ b/Lib/compression/zstd/_zstdfile.py @@ -60,7 +60,8 @@ def __init__( COMPRESSION_LEVEL_DEFAULT options is a dict object that contains advanced compression - parameters. See CParameter or DParameter for more on options. + parameters. See CompressionParameter or DecompressionParameter for more + information on options. zstd_dict is a ZstdDict object, a pre-trained Zstandard dictionary. See the function train_dict for how to train a ZstdDict on sample data. 
diff --git a/Lib/test/test_zstd.py b/Lib/test/test_zstd.py index 1dac944390125d..24d4190872054e 100644 --- a/Lib/test/test_zstd.py +++ b/Lib/test/test_zstd.py @@ -32,8 +32,8 @@ get_frame_size, finalize_dict, train_dict, - CParameter, - DParameter, + CompressionParameter, + DecompressionParameter, Strategy, ZstdFile, ) @@ -67,14 +67,14 @@ def setUpModule(): global SUPPORT_MULTITHREADING - SUPPORT_MULTITHREADING = CParameter.nb_workers.bounds() != (0, 0) + SUPPORT_MULTITHREADING = CompressionParameter.nb_workers.bounds() != (0, 0) # uncompressed size 130KB, more than a zstd block. # with a frame epilogue, 4 bytes checksum. global DAT_130K_D DAT_130K_D = bytes([random.randint(0, 127) for _ in range(130*_1K)]) global DAT_130K_C - DAT_130K_C = compress(DAT_130K_D, options={CParameter.checksum_flag:1}) + DAT_130K_C = compress(DAT_130K_D, options={CompressionParameter.checksum_flag:1}) global DECOMPRESSED_DAT DECOMPRESSED_DAT = b'abcdefg123456' * 1000 @@ -132,7 +132,7 @@ def test_version(self): self.assertEqual(s, zstd_version) def test_compressionLevel_values(self): - min, max = CParameter.compression_level.bounds() + min, max = CompressionParameter.compression_level.bounds() self.assertIs(type(COMPRESSION_LEVEL_DEFAULT), int) self.assertIs(type(min), int) self.assertIs(type(max), int) @@ -146,7 +146,7 @@ def test_roundtrip_default(self): def test_roundtrip_level(self): raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6] - level_min, level_max = CParameter.compression_level.bounds() + level_min, level_max = CompressionParameter.compression_level.bounds() for level in range(max(-20, level_min), level_max + 1): dat1 = compress(raw_dat, level) @@ -202,7 +202,7 @@ def test_simple_compress_bad_args(self): ZstdCompressor(options={2**31: 100}) with self.assertRaises(ZstdError): - ZstdCompressor(options={CParameter.window_log: 100}) + ZstdCompressor(options={CompressionParameter.window_log: 100}) with self.assertRaises(ZstdError): ZstdCompressor(options={3333: 100}) @@ -228,56 +228,56 @@ def test_simple_compress_bad_args(self): zc.flush(zc.FLUSH_FRAME) def test_compress_parameters(self): - d = {CParameter.compression_level : 10, - - CParameter.window_log : 12, - CParameter.hash_log : 10, - CParameter.chain_log : 12, - CParameter.search_log : 12, - CParameter.min_match : 4, - CParameter.target_length : 12, - CParameter.strategy : Strategy.lazy, - - CParameter.enable_long_distance_matching : 1, - CParameter.ldm_hash_log : 12, - CParameter.ldm_min_match : 11, - CParameter.ldm_bucket_size_log : 5, - CParameter.ldm_hash_rate_log : 12, - - CParameter.content_size_flag : 1, - CParameter.checksum_flag : 1, - CParameter.dict_id_flag : 0, - - CParameter.nb_workers : 2 if SUPPORT_MULTITHREADING else 0, - CParameter.job_size : 5*_1M if SUPPORT_MULTITHREADING else 0, - CParameter.overlap_log : 9 if SUPPORT_MULTITHREADING else 0, + d = {CompressionParameter.compression_level : 10, + + CompressionParameter.window_log : 12, + CompressionParameter.hash_log : 10, + CompressionParameter.chain_log : 12, + CompressionParameter.search_log : 12, + CompressionParameter.min_match : 4, + CompressionParameter.target_length : 12, + CompressionParameter.strategy : Strategy.lazy, + + CompressionParameter.enable_long_distance_matching : 1, + CompressionParameter.ldm_hash_log : 12, + CompressionParameter.ldm_min_match : 11, + CompressionParameter.ldm_bucket_size_log : 5, + CompressionParameter.ldm_hash_rate_log : 12, + + CompressionParameter.content_size_flag : 1, + CompressionParameter.checksum_flag : 1, + 
CompressionParameter.dict_id_flag : 0, + + CompressionParameter.nb_workers : 2 if SUPPORT_MULTITHREADING else 0, + CompressionParameter.job_size : 5*_1M if SUPPORT_MULTITHREADING else 0, + CompressionParameter.overlap_log : 9 if SUPPORT_MULTITHREADING else 0, } ZstdCompressor(options=d) # larger than signed int, ValueError d1 = d.copy() - d1[CParameter.ldm_bucket_size_log] = 2**31 + d1[CompressionParameter.ldm_bucket_size_log] = 2**31 self.assertRaises(ValueError, ZstdCompressor, options=d1) # clamp compressionLevel - level_min, level_max = CParameter.compression_level.bounds() + level_min, level_max = CompressionParameter.compression_level.bounds() compress(b'', level_max+1) compress(b'', level_min-1) - compress(b'', options={CParameter.compression_level:level_max+1}) - compress(b'', options={CParameter.compression_level:level_min-1}) + compress(b'', options={CompressionParameter.compression_level:level_max+1}) + compress(b'', options={CompressionParameter.compression_level:level_min-1}) # zstd lib doesn't support MT compression if not SUPPORT_MULTITHREADING: with self.assertRaises(ZstdError): - ZstdCompressor(options={CParameter.nb_workers:4}) + ZstdCompressor(options={CompressionParameter.nb_workers:4}) with self.assertRaises(ZstdError): - ZstdCompressor(options={CParameter.job_size:4}) + ZstdCompressor(options={CompressionParameter.job_size:4}) with self.assertRaises(ZstdError): - ZstdCompressor(options={CParameter.overlap_log:4}) + ZstdCompressor(options={CompressionParameter.overlap_log:4}) # out of bounds error msg - option = {CParameter.window_log:100} + option = {CompressionParameter.window_log:100} with self.assertRaisesRegex(ZstdError, (r'Error when setting zstd compression parameter "window_log", ' r'it should \d+ <= value <= \d+, provided value is 100\. ' @@ -286,7 +286,7 @@ def test_compress_parameters(self): def test_unknown_compression_parameter(self): KEY = 100001234 - option = {CParameter.compression_level: 10, + option = {CompressionParameter.compression_level: 10, KEY: 200000000} pattern = r'Zstd compression parameter.*?"unknown parameter \(key %d\)"' \ % KEY @@ -299,8 +299,8 @@ def test_zstd_multithread_compress(self): size = 40*_1M b = THIS_FILE_BYTES * (size // len(THIS_FILE_BYTES)) - options = {CParameter.compression_level : 4, - CParameter.nb_workers : 2} + options = {CompressionParameter.compression_level : 4, + CompressionParameter.nb_workers : 2} # compress() dat1 = compress(b, options=options) @@ -388,7 +388,7 @@ def test_simple_decompress_bad_args(self): ZstdDecompressor(options={2**31 : 100}) with self.assertRaises(ZstdError): - ZstdDecompressor(options={DParameter.window_log_max:100}) + ZstdDecompressor(options={DecompressionParameter.window_log_max:100}) with self.assertRaises(ZstdError): ZstdDecompressor(options={3333 : 100}) @@ -400,16 +400,16 @@ def test_simple_decompress_bad_args(self): lzd.decompress(empty) def test_decompress_parameters(self): - d = {DParameter.window_log_max : 15} + d = {DecompressionParameter.window_log_max : 15} ZstdDecompressor(options=d) # larger than signed int, ValueError d1 = d.copy() - d1[DParameter.window_log_max] = 2**31 + d1[DecompressionParameter.window_log_max] = 2**31 self.assertRaises(ValueError, ZstdDecompressor, None, d1) # out of bounds error msg - options = {DParameter.window_log_max:100} + options = {DecompressionParameter.window_log_max:100} with self.assertRaisesRegex(ZstdError, (r'Error when setting zstd decompression parameter "window_log_max", ' r'it should \d+ <= value <= \d+, provided value is 100\. 
' @@ -418,7 +418,7 @@ def test_decompress_parameters(self): def test_unknown_decompression_parameter(self): KEY = 100001234 - options = {DParameter.window_log_max: DParameter.window_log_max.bounds()[1], + options = {DecompressionParameter.window_log_max: DecompressionParameter.window_log_max.bounds()[1], KEY: 200000000} pattern = r'Zstd decompression parameter.*?"unknown parameter \(key %d\)"' \ % KEY @@ -517,7 +517,7 @@ def test_decompressor_arg(self): ZstdDecompressor() ZstdDecompressor(zd, {}) - ZstdDecompressor(zstd_dict=zd, options={DParameter.window_log_max:25}) + ZstdDecompressor(zstd_dict=zd, options={DecompressionParameter.window_log_max:25}) def test_decompressor_1(self): # empty @@ -662,7 +662,7 @@ class DecompressorFlagsTestCase(unittest.TestCase): @classmethod def setUpClass(cls): - options = {CParameter.checksum_flag:1} + options = {CompressionParameter.checksum_flag:1} c = ZstdCompressor(options=options) cls.DECOMPRESSED_42 = b'a'*42 @@ -1294,9 +1294,9 @@ def test_as_digested_dict(self): zd.as_undigested_dict = b'1234' def test_advanced_compression_parameters(self): - options = {CParameter.compression_level: 6, - CParameter.window_log: 20, - CParameter.enable_long_distance_matching: 1} + options = {CompressionParameter.compression_level: 6, + CompressionParameter.window_log: 20, + CompressionParameter.enable_long_distance_matching: 1} # automatically select dat = compress(SAMPLES[0], options=options, zstd_dict=TRAINED_DICT) @@ -1327,14 +1327,14 @@ def test_init(self): with ZstdFile(io.BytesIO(), "w", level=12) as f: pass - with ZstdFile(io.BytesIO(), "w", options={CParameter.checksum_flag:1}) as f: + with ZstdFile(io.BytesIO(), "w", options={CompressionParameter.checksum_flag:1}) as f: pass with ZstdFile(io.BytesIO(), "w", options={}) as f: pass with ZstdFile(io.BytesIO(), "w", level=20, zstd_dict=TRAINED_DICT) as f: pass - with ZstdFile(io.BytesIO(), "r", options={DParameter.window_log_max:25}) as f: + with ZstdFile(io.BytesIO(), "r", options={DecompressionParameter.window_log_max:25}) as f: pass with ZstdFile(io.BytesIO(), "r", options={}, zstd_dict=TRAINED_DICT) as f: pass @@ -1420,10 +1420,13 @@ def test_init_bad_mode(self): with self.assertRaises(ValueError): ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rw") - with self.assertRaisesRegex(TypeError, r"NOT be CParameter"): - ZstdFile(io.BytesIO(), 'rb', options={CParameter.compression_level:5}) - with self.assertRaisesRegex(TypeError, r"NOT be DParameter"): - ZstdFile(io.BytesIO(), 'wb', options={DParameter.window_log_max:21}) + with self.assertRaisesRegex(TypeError, r"NOT be CompressionParameter"): + ZstdFile(io.BytesIO(), 'rb', + options={CompressionParameter.compression_level:5}) + with self.assertRaisesRegex(TypeError, + r"NOT be DecompressionParameter"): + ZstdFile(io.BytesIO(), 'wb', + options={DecompressionParameter.window_log_max:21}) with self.assertRaises(TypeError): ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", level=12) @@ -1435,14 +1438,14 @@ def test_init_bad_check(self): with self.assertRaises(ZstdError): ZstdFile(io.BytesIO(), "w", options={999:9999}) with self.assertRaises(ZstdError): - ZstdFile(io.BytesIO(), "w", options={CParameter.window_log:99}) + ZstdFile(io.BytesIO(), "w", options={CompressionParameter.window_log:99}) with self.assertRaises(TypeError): ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", options=33) with self.assertRaises(ValueError): ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), - options={DParameter.window_log_max:2**31}) + 
options={DecompressionParameter.window_log_max:2**31}) with self.assertRaises(ZstdError): ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), @@ -1639,7 +1642,7 @@ def test_read_0(self): self.assertEqual(f.read(0), b"") self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), - options={DParameter.window_log_max:20}) as f: + options={DecompressionParameter.window_log_max:20}) as f: self.assertEqual(f.read(0), b"") # empty file @@ -1854,16 +1857,16 @@ def test_write(self): self.assertEqual(dst.getvalue(), expected) with io.BytesIO() as dst: - with ZstdFile(dst, "w", options={CParameter.checksum_flag:1}) as f: + with ZstdFile(dst, "w", options={CompressionParameter.checksum_flag:1}) as f: f.write(raw_data) - comp = ZstdCompressor(options={CParameter.checksum_flag:1}) + comp = ZstdCompressor(options={CompressionParameter.checksum_flag:1}) expected = comp.compress(raw_data) + comp.flush() self.assertEqual(dst.getvalue(), expected) with io.BytesIO() as dst: - options = {CParameter.compression_level:-5, - CParameter.checksum_flag:1} + options = {CompressionParameter.compression_level:-5, + CompressionParameter.checksum_flag:1} with ZstdFile(dst, "w", options=options) as f: f.write(raw_data) @@ -2321,11 +2324,11 @@ def test_bad_params(self): os.remove(TESTFN) def test_option(self): - options = {DParameter.window_log_max:25} + options = {DecompressionParameter.window_log_max:25} with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb", options=options) as f: self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) - options = {CParameter.compression_level:12} + options = {CompressionParameter.compression_level:12} with io.BytesIO() as bio: with open(bio, "wb", options=options) as f: f.write(DECOMPRESSED_100_PLUS_32KB) diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c index 1c122466f671ae..a07deeee85a95a 100644 --- a/Modules/_zstd/_zstdmodule.c +++ b/Modules/_zstd/_zstdmodule.c @@ -402,11 +402,11 @@ _zstd__finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes, _zstd._get_param_bounds is_compress: bool - True for CParameter, False for DParameter. + True for CompressionParameter, False for DecompressionParameter. parameter: int The parameter to get bounds. -Internal function, get CParameter/DParameter bounds. +Internal function, get CompressionParameter/DecompressionParameter bounds. [clinic start generated code]*/ static PyObject * @@ -514,11 +514,11 @@ _zstd__get_frame_info_impl(PyObject *module, Py_buffer *frame_buffer) _zstd._set_parameter_types c_parameter_type: object(subclass_of='&PyType_Type') - CParameter IntEnum type object + CompressionParameter IntEnum type object d_parameter_type: object(subclass_of='&PyType_Type') - DParameter IntEnum type object + DecompressionParameter IntEnum type object -Internal function, set CParameter/DParameter types for validity check. +Internal function, set CompressionParameter/DecompressionParameter types for validity check. 
[clinic start generated code]*/ static PyObject * @@ -530,8 +530,8 @@ _zstd__set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type, if (!PyType_Check(c_parameter_type) || !PyType_Check(d_parameter_type)) { PyErr_SetString(PyExc_ValueError, - "The two arguments should be CParameter and " - "DParameter types."); + "The two arguments should be CompressionParameter and " + "DecompressionParameter types."); return NULL; } diff --git a/Modules/_zstd/compressor.c b/Modules/_zstd/compressor.c index d0f677be821572..b735981e7476d5 100644 --- a/Modules/_zstd/compressor.c +++ b/Modules/_zstd/compressor.c @@ -71,14 +71,14 @@ _PyZstd_set_c_parameters(ZstdCompressor *self, PyObject *level_or_options, if (Py_TYPE(key) == mod_state->DParameter_type) { PyErr_SetString(PyExc_TypeError, "Key of compression option dict should " - "NOT be DParameter."); + "NOT be DecompressionParameter."); return -1; } int key_v = PyLong_AsInt(key); if (key_v == -1 && PyErr_Occurred()) { PyErr_SetString(PyExc_ValueError, - "Key of options dict should be a CParameter attribute."); + "Key of options dict should be a CompressionParameter attribute."); return -1; } diff --git a/Modules/_zstd/decompressor.c b/Modules/_zstd/decompressor.c index 4e3a28068be130..a4be180c0088fc 100644 --- a/Modules/_zstd/decompressor.c +++ b/Modules/_zstd/decompressor.c @@ -84,7 +84,7 @@ _PyZstd_set_d_parameters(ZstdDecompressor *self, PyObject *options) if (Py_TYPE(key) == mod_state->CParameter_type) { PyErr_SetString(PyExc_TypeError, "Key of decompression options dict should " - "NOT be CParameter."); + "NOT be CompressionParameter."); return -1; } @@ -92,7 +92,7 @@ _PyZstd_set_d_parameters(ZstdDecompressor *self, PyObject *options) int key_v = PyLong_AsInt(key); if (key_v == -1 && PyErr_Occurred()) { PyErr_SetString(PyExc_ValueError, - "Key of options dict should be a DParameter attribute."); + "Key of options dict should be a DecompressionParameter attribute."); return -1; } From fa0cb0c0a05a290764584ae0140eb9b55b73496f Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+aa-turner@users.noreply.github.com> Date: Mon, 5 May 2025 16:27:09 +0100 Subject: [PATCH 44/55] regen clinic --- Modules/_zstd/_zstdmodule.c | 4 ++-- Modules/_zstd/clinic/_zstdmodule.c.h | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c index a07deeee85a95a..d8c85840a11e17 100644 --- a/Modules/_zstd/_zstdmodule.c +++ b/Modules/_zstd/_zstdmodule.c @@ -412,7 +412,7 @@ Internal function, get CompressionParameter/DecompressionParameter bounds. 
static PyObject * _zstd__get_param_bounds_impl(PyObject *module, int is_compress, int parameter) -/*[clinic end generated code: output=b751dc710f89ef55 input=fb21ff96aff65df1]*/ +/*[clinic end generated code: output=b751dc710f89ef55 input=1aae4fcf8faf4e0f]*/ { ZSTD_bounds bound; if (is_compress) { @@ -524,7 +524,7 @@ Internal function, set CompressionParameter/DecompressionParameter types for val static PyObject * _zstd__set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type, PyObject *d_parameter_type) -/*[clinic end generated code: output=a13d4890ccbd2873 input=3e7d0d37c3a1045a]*/ +/*[clinic end generated code: output=a13d4890ccbd2873 input=4535545d903853d3]*/ { _zstd_state* const mod_state = get_zstd_state(module); diff --git a/Modules/_zstd/clinic/_zstdmodule.c.h b/Modules/_zstd/clinic/_zstdmodule.c.h index 94f14ed858cdae..1b69150c4f40ca 100644 --- a/Modules/_zstd/clinic/_zstdmodule.c.h +++ b/Modules/_zstd/clinic/_zstdmodule.c.h @@ -149,10 +149,10 @@ PyDoc_STRVAR(_zstd__get_param_bounds__doc__, "_get_param_bounds($module, /, is_compress, parameter)\n" "--\n" "\n" -"Internal function, get CParameter/DParameter bounds.\n" +"Internal function, get CompressionParameter/DecompressionParameter bounds.\n" "\n" " is_compress\n" -" True for CParameter, False for DParameter.\n" +" True for CompressionParameter, False for DecompressionParameter.\n" " parameter\n" " The parameter to get bounds."); @@ -360,12 +360,12 @@ PyDoc_STRVAR(_zstd__set_parameter_types__doc__, "_set_parameter_types($module, /, c_parameter_type, d_parameter_type)\n" "--\n" "\n" -"Internal function, set CParameter/DParameter types for validity check.\n" +"Internal function, set CompressionParameter/DecompressionParameter types for validity check.\n" "\n" " c_parameter_type\n" -" CParameter IntEnum type object\n" +" CompressionParameter IntEnum type object\n" " d_parameter_type\n" -" DParameter IntEnum type object"); +" DecompressionParameter IntEnum type object"); #define _ZSTD__SET_PARAMETER_TYPES_METHODDEF \ {"_set_parameter_types", _PyCFunction_CAST(_zstd__set_parameter_types), METH_FASTCALL|METH_KEYWORDS, _zstd__set_parameter_types__doc__}, @@ -429,4 +429,4 @@ _zstd__set_parameter_types(PyObject *module, PyObject *const *args, Py_ssize_t n exit: return return_value; } -/*[clinic end generated code: output=f4530f3e3439cbe7 input=a9049054013a1b77]*/ +/*[clinic end generated code: output=af9311ffc4321e98 input=a9049054013a1b77]*/ From 74e4d2b69814abf885622cfb4ebf313846d917f7 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Mon, 5 May 2025 08:29:43 -0700 Subject: [PATCH 45/55] Fix whitespace issue --- Lib/test/test_zstd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_zstd.py b/Lib/test/test_zstd.py index 24d4190872054e..f4a25376e5234a 100644 --- a/Lib/test/test_zstd.py +++ b/Lib/test/test_zstd.py @@ -1421,11 +1421,11 @@ def test_init_bad_mode(self): ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rw") with self.assertRaisesRegex(TypeError, r"NOT be CompressionParameter"): - ZstdFile(io.BytesIO(), 'rb', + ZstdFile(io.BytesIO(), 'rb', options={CompressionParameter.compression_level:5}) with self.assertRaisesRegex(TypeError, r"NOT be DecompressionParameter"): - ZstdFile(io.BytesIO(), 'wb', + ZstdFile(io.BytesIO(), 'wb', options={DecompressionParameter.window_log_max:21}) with self.assertRaises(TypeError): From 03fff3dc97e164b5e0df0dce60cbd1a93de50e96 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Mon, 5 May 2025 09:05:37 -0700 Subject: [PATCH 46/55] Remove 
makefile test dir --- Makefile.pre.in | 1 - 1 file changed, 1 deletion(-) diff --git a/Makefile.pre.in b/Makefile.pre.in index 381b5f46312da7..17e0c9904cc3aa 100644 --- a/Makefile.pre.in +++ b/Makefile.pre.in @@ -2677,7 +2677,6 @@ TESTSUBDIRS= idlelib/idle_test \ test/test_zipfile/_path \ test/test_zoneinfo \ test/test_zoneinfo/data \ - test/test_zstd \ test/tkinterdata \ test/tokenizedata \ test/tracedmodules \ From a99c5dd67375e35f7b2912348a5b6fd22af0be58 Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+aa-turner@users.noreply.github.com> Date: Mon, 5 May 2025 11:32:35 -0700 Subject: [PATCH 47/55] swap order of parameters in _get_param_bounds --- Lib/compression/zstd/__init__.py | 4 ++-- Modules/_zstd/_zstdmodule.c | 10 +++++----- Modules/_zstd/clinic/_zstdmodule.c.h | 30 ++++++++++++++-------------- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index bd95105c44f5a0..a6e20cbdb133cf 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -212,7 +212,7 @@ def bounds(self): Both lower and upper bounds are inclusive. """ - return _zstd._get_param_bounds(is_compress=True, parameter=self.value) + return _zstd._get_param_bounds(self.value, is_compress=True) class DecompressionParameter(enum.IntEnum): @@ -226,7 +226,7 @@ def bounds(self): Both lower and upper bounds are inclusive. """ - return _zstd._get_param_bounds(is_compress=False, parameter=self.value) + return _zstd._get_param_bounds(self.value, is_compress=False) class Strategy(enum.IntEnum): diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c index d8c85840a11e17..81e950b6c19d34 100644 --- a/Modules/_zstd/_zstdmodule.c +++ b/Modules/_zstd/_zstdmodule.c @@ -401,18 +401,18 @@ _zstd__finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes, /*[clinic input] _zstd._get_param_bounds - is_compress: bool - True for CompressionParameter, False for DecompressionParameter. parameter: int The parameter to get bounds. + is_compress: bool + True for CompressionParameter, False for DecompressionParameter. Internal function, get CompressionParameter/DecompressionParameter bounds. 
[clinic start generated code]*/ static PyObject * -_zstd__get_param_bounds_impl(PyObject *module, int is_compress, - int parameter) -/*[clinic end generated code: output=b751dc710f89ef55 input=1aae4fcf8faf4e0f]*/ +_zstd__get_param_bounds_impl(PyObject *module, int parameter, + int is_compress) +/*[clinic end generated code: output=9892cd822f937e79 input=884cd1a01125267d]*/ { ZSTD_bounds bound; if (is_compress) { diff --git a/Modules/_zstd/clinic/_zstdmodule.c.h b/Modules/_zstd/clinic/_zstdmodule.c.h index 1b69150c4f40ca..2f8225389b7aea 100644 --- a/Modules/_zstd/clinic/_zstdmodule.c.h +++ b/Modules/_zstd/clinic/_zstdmodule.c.h @@ -146,22 +146,22 @@ _zstd__finalize_dict(PyObject *module, PyObject *const *args, Py_ssize_t nargs) } PyDoc_STRVAR(_zstd__get_param_bounds__doc__, -"_get_param_bounds($module, /, is_compress, parameter)\n" +"_get_param_bounds($module, /, parameter, is_compress)\n" "--\n" "\n" "Internal function, get CompressionParameter/DecompressionParameter bounds.\n" "\n" -" is_compress\n" -" True for CompressionParameter, False for DecompressionParameter.\n" " parameter\n" -" The parameter to get bounds."); +" The parameter to get bounds.\n" +" is_compress\n" +" True for CompressionParameter, False for DecompressionParameter."); #define _ZSTD__GET_PARAM_BOUNDS_METHODDEF \ {"_get_param_bounds", _PyCFunction_CAST(_zstd__get_param_bounds), METH_FASTCALL|METH_KEYWORDS, _zstd__get_param_bounds__doc__}, static PyObject * -_zstd__get_param_bounds_impl(PyObject *module, int is_compress, - int parameter); +_zstd__get_param_bounds_impl(PyObject *module, int parameter, + int is_compress); static PyObject * _zstd__get_param_bounds(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) @@ -178,7 +178,7 @@ _zstd__get_param_bounds(PyObject *module, PyObject *const *args, Py_ssize_t narg } _kwtuple = { .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS) .ob_hash = -1, - .ob_item = { &_Py_ID(is_compress), &_Py_ID(parameter), }, + .ob_item = { &_Py_ID(parameter), &_Py_ID(is_compress), }, }; #undef NUM_KEYWORDS #define KWTUPLE (&_kwtuple.ob_base.ob_base) @@ -187,7 +187,7 @@ _zstd__get_param_bounds(PyObject *module, PyObject *const *args, Py_ssize_t narg # define KWTUPLE NULL #endif // !Py_BUILD_CORE - static const char * const _keywords[] = {"is_compress", "parameter", NULL}; + static const char * const _keywords[] = {"parameter", "is_compress", NULL}; static _PyArg_Parser _parser = { .keywords = _keywords, .fname = "_get_param_bounds", @@ -195,23 +195,23 @@ _zstd__get_param_bounds(PyObject *module, PyObject *const *args, Py_ssize_t narg }; #undef KWTUPLE PyObject *argsbuf[2]; - int is_compress; int parameter; + int is_compress; args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, /*minpos*/ 2, /*maxpos*/ 2, /*minkw*/ 0, /*varpos*/ 0, argsbuf); if (!args) { goto exit; } - is_compress = PyObject_IsTrue(args[0]); - if (is_compress < 0) { + parameter = PyLong_AsInt(args[0]); + if (parameter == -1 && PyErr_Occurred()) { goto exit; } - parameter = PyLong_AsInt(args[1]); - if (parameter == -1 && PyErr_Occurred()) { + is_compress = PyObject_IsTrue(args[1]); + if (is_compress < 0) { goto exit; } - return_value = _zstd__get_param_bounds_impl(module, is_compress, parameter); + return_value = _zstd__get_param_bounds_impl(module, parameter, is_compress); exit: return return_value; @@ -429,4 +429,4 @@ _zstd__set_parameter_types(PyObject *module, PyObject *const *args, Py_ssize_t n exit: return return_value; } -/*[clinic end generated code: output=af9311ffc4321e98 
input=a9049054013a1b77]*/ +/*[clinic end generated code: output=189c462236a7096c input=a9049054013a1b77]*/ From bf94aad387a3964bd83eb2190dc79e4a02efe607 Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+aa-turner@users.noreply.github.com> Date: Tue, 6 May 2025 00:44:24 +0100 Subject: [PATCH 48/55] Sort imports --- Lib/compression/zstd/__init__.py | 7 ++----- Lib/compression/zstd/_zstdfile.py | 2 -- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index a6e20cbdb133cf..a638a99fc6b76b 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -26,13 +26,10 @@ "ZstdError", ) +import _zstd import enum - -from compression.zstd._zstdfile import ZstdFile, open, _nbytes from _zstd import * - -import _zstd - +from compression.zstd._zstdfile import ZstdFile, open, _nbytes _ZSTD_CStreamSizes = _zstd._ZSTD_CStreamSizes _ZSTD_DStreamSizes = _zstd._ZSTD_DStreamSizes diff --git a/Lib/compression/zstd/_zstdfile.py b/Lib/compression/zstd/_zstdfile.py index 366a3e48930f33..df7c6d0a5fbbc6 100644 --- a/Lib/compression/zstd/_zstdfile.py +++ b/Lib/compression/zstd/_zstdfile.py @@ -1,7 +1,5 @@ import io - from os import PathLike - from _zstd import (ZstdCompressor, ZstdDecompressor, _ZSTD_DStreamSizes, ZstdError) from compression._common import _streams From 5b45ec7d1f6cd461cacc5d30e1d02ada82bc9aaa Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+aa-turner@users.noreply.github.com> Date: Tue, 6 May 2025 00:46:16 +0100 Subject: [PATCH 49/55] Improve docstrings --- Lib/compression/zstd/__init__.py | 57 ++++++++++++++----------------- Lib/compression/zstd/_zstdfile.py | 31 ++++++++--------- 2 files changed, 39 insertions(+), 49 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index a638a99fc6b76b..477ab525fe3c87 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -38,6 +38,7 @@ COMPRESSION_LEVEL_DEFAULT = _zstd._compressionLevel_values[0] +"""The default compression level for Zstandard, currently '3'.""" class FrameInfo: @@ -57,25 +58,24 @@ def __setattr__(self, name, _): def get_frame_info(frame_buffer): - """Get zstd frame information from a frame header. + """Get Zstandard frame information from a frame header. - *frame_buffer* is a bytes-like object. It should starts from the beginning + *frame_buffer* is a bytes-like object. It should start from the beginning of a frame, and needs to include at least the frame header (6 to 18 bytes). - Return a FrameInfo dataclass, which currently has two attributes - + The returned FrameInfo object has two attributes. 'decompressed_size' is the size in bytes of the data in the frame when decompressed, or None when the decompressed size is unknown. - - 'dictionary_id' is a 32-bit unsigned integer value. 0 means dictionary ID - was not recorded in the frame header, the frame may or may not need a - dictionary to be decoded, and the ID of such a dictionary is not specified. + 'dictionary_id' is an int in the range (0, 2**32). The special value 0 + means that the dictionary ID was not recorded in the frame header, + the frame may or may not need a dictionary to be decoded, + and the ID of such a dictionary is not specified. """ return FrameInfo(*_zstd._get_frame_info(frame_buffer)) def train_dict(samples, dict_size): - """Train a zstd dictionary, return a ZstdDict object. + """Return a ZstdDict representing a trained Zstandard dictionary. 
*samples* is an iterable of samples, where a sample is a bytes-like object representing a file. @@ -97,24 +97,22 @@ def train_dict(samples, dict_size): def finalize_dict(zstd_dict, samples, dict_size, level): - """Finalize a zstd dictionary, return a ZstdDict object. + """Return a ZstdDict representing a finalized Zstandard dictionary. Given a custom content as a basis for dictionary, and a set of samples, - finalize dictionary by adding headers and statistics according to the zstd - dictionary format. + finalize *zstd_dict* by adding headers and statistics according to the + Zstandard dictionary format. You may compose an effective dictionary content by hand, which is used as basis dictionary, and use some samples to finalize a dictionary. The basis - dictionary can be a "raw content" dictionary, see is_raw parameter in - ZstdDict.__init__ method. + dictionary may be a "raw content" dictionary. See *is_raw* in ZstdDict. - *zstd_dict* is a ZstdDict object, the basis dictionary. - *samples* is an iterable of samples, a sample is a bytes-like object + *samples* is an iterable of samples, where a sample is a bytes-like object representing a file. *dict_size* is the dictionary's maximum size, in bytes. - *level* is the compression level expected to use in production. The - statistics for each compression level differ, so tuning the - dictionary for the compression level can help quite a bit. + *level* is the expected compression level. The statistics for each + compression level differ, so tuning the dictionary to the compression level + can provide improvements. """ # Check arguments' type @@ -140,11 +138,9 @@ def compress(data, level=None, options=None, zstd_dict=None): """Return Zstandard compressed *data* as bytes. *level* is an int specifying the compression level to use, defaulting to - COMPRESSION_LEVEL_DEFAULT - + COMPRESSION_LEVEL_DEFAULT ('3'). *options* is a dict object that contains advanced compression parameters. See CompressionParameter for more on options. - *zstd_dict* is a ZstdDict object, a pre-trained Zstandard dictionary. See the function train_dict for how to train a ZstdDict on sample data. @@ -158,7 +154,6 @@ def decompress(data, zstd_dict=None, options=None): *zstd_dict* is a ZstdDict object, a pre-trained Zstandard dictionary. See the function train_dict for how to train a ZstdDict on sample data. - *options* is a dict object that contains advanced compression parameters. See DecompressionParameter for more on options. @@ -178,7 +173,7 @@ def decompress(data, zstd_dict=None, options=None): class CompressionParameter(enum.IntEnum): - """Compression parameters""" + """Compression parameters.""" compression_level = _zstd._ZSTD_c_compressionLevel window_log = _zstd._ZSTD_c_windowLog @@ -204,24 +199,22 @@ class CompressionParameter(enum.IntEnum): overlap_log = _zstd._ZSTD_c_overlapLog def bounds(self): - """Returns a tuple of ints (lower, upper), representing the bounds of a - compression parameter. + """Return the (lower, upper) int bounds of a compression parameter. - Both lower and upper bounds are inclusive. + Both the lower and upper bounds are inclusive. """ return _zstd._get_param_bounds(self.value, is_compress=True) class DecompressionParameter(enum.IntEnum): - """Decompression parameters""" + """Decompression parameters.""" window_log_max = _zstd._ZSTD_d_windowLogMax def bounds(self): - """Returns a tuple of ints (lower, upper) representing the bounds of a - decompression parameter. + """Return the (lower, upper) int bounds of a decompression parameter. 
- Both lower and upper bounds are inclusive. + Both the lower and upper bounds are inclusive. """ return _zstd._get_param_bounds(self.value, is_compress=False) @@ -245,5 +238,5 @@ class Strategy(enum.IntEnum): btultra2 = _zstd._ZSTD_btultra2 -# Set CompressionParameter/DecompressionParameter types for validity check +# Check validity of the CompressionParameter & DecompressionParameter types _zstd._set_parameter_types(CompressionParameter, DecompressionParameter) diff --git a/Lib/compression/zstd/_zstdfile.py b/Lib/compression/zstd/_zstdfile.py index df7c6d0a5fbbc6..5264df6f0c90ac 100644 --- a/Lib/compression/zstd/_zstdfile.py +++ b/Lib/compression/zstd/_zstdfile.py @@ -21,14 +21,13 @@ def _nbytes(dat): class ZstdFile(_streams.BaseStream): - """A file object providing transparent zstd (de)compression. + """A file-like object providing transparent Zstandard (de)compression. A ZstdFile can act as a wrapper for an existing file object, or refer directly to a named file on disk. - Note that ZstdFile provides a *binary* file interface - data read is - returned as bytes, and data to be written should be an object that - supports the Buffer Protocol. + ZstdFile provides a *binary* file interface. Data is read and returned as + bytes, and may only be written to objects that support the Buffer Protocol. """ FLUSH_BLOCK = ZstdCompressor.FLUSH_BLOCK @@ -44,25 +43,23 @@ def __init__( options=None, zstd_dict=None, ): - """Open a zstd compressed file in binary mode. + """Open a Zstandard compressed file in binary mode. - file can be either an actual file name (given as a str, bytes, or - PathLike object), in which case the named file is opened, or it can be - an existing file object to read from or write to. + *file* can be either an file-like object, or a file name to open. - mode can be "r" for reading (default), "w" for (over)writing, "x" for + *mode* can be "r" for reading (default), "w" for (over)writing, "x" for creating exclusively, or "a" for appending. These can equivalently be given as "rb", "wb", "xb" and "ab" respectively. - level is an int specifying the compression level to use, defaulting to - COMPRESSION_LEVEL_DEFAULT + *level* is an optional int specifying the compression level to use, + or COMPRESSION_LEVEL_DEFAULT if not given. - options is a dict object that contains advanced compression - parameters. See CompressionParameter or DecompressionParameter for more - information on options. + *options* is an optional dict for advanced compression parameters. + See CompressionParameter and DecompressionParameter for the possible + options. - zstd_dict is a ZstdDict object, a pre-trained Zstandard dictionary. - See the function train_dict for how to train a ZstdDict on sample data. + *zstd_dict* is an optional ZstdDict object, a pre-trained Zstandard + dictionary. See train_dict() to train ZstdDict on sample data. """ self._fp = None self._close_fp = False @@ -323,7 +320,7 @@ def open( errors=None, newline=None, ): - """Open a zstd compressed file in binary or text mode. + """Open a Zstandard compressed file in binary or text mode. 
file can be either a file name (given as a str, bytes, or PathLike object), in which case the named file is opened, or it can be an existing file object From b0eca5a9259744200be8b605d896c50a186752cb Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+aa-turner@users.noreply.github.com> Date: Tue, 6 May 2025 00:47:30 +0100 Subject: [PATCH 50/55] Remove comments --- Lib/compression/zstd/__init__.py | 1 - Lib/compression/zstd/_zstdfile.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 477ab525fe3c87..0ad545486d6421 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -115,7 +115,6 @@ def finalize_dict(zstd_dict, samples, dict_size, level): can provide improvements. """ - # Check arguments' type if not isinstance(zstd_dict, ZstdDict): raise TypeError('zstd_dict argument should be a ZstdDict object.') if not isinstance(dict_size, int): diff --git a/Lib/compression/zstd/_zstdfile.py b/Lib/compression/zstd/_zstdfile.py index 5264df6f0c90ac..775e063d1f3344 100644 --- a/Lib/compression/zstd/_zstdfile.py +++ b/Lib/compression/zstd/_zstdfile.py @@ -68,7 +68,6 @@ def __init__( if not isinstance(mode, str): raise ValueError("mode must be a str") - # Read or write mode if options is not None and not isinstance(options, dict): raise TypeError("options must be a dict or None") mode = mode.removesuffix("b") # handle rb, wb, xb, ab @@ -87,7 +86,6 @@ def __init__( else: raise ValueError(f"Invalid mode: {mode!r}") - # File object if isinstance(file, (str, bytes, PathLike)): self._fp = io.open(file, f'{mode}b') self._close_fp = True From c0d0e10fe3fa3578f8c635f48aa09a3bc76f54bb Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+aa-turner@users.noreply.github.com> Date: Tue, 6 May 2025 00:48:00 +0100 Subject: [PATCH 51/55] Remove unused private variables --- Lib/compression/zstd/__init__.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 0ad545486d6421..9296c141a24597 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -31,12 +31,6 @@ from _zstd import * from compression.zstd._zstdfile import ZstdFile, open, _nbytes -_ZSTD_CStreamSizes = _zstd._ZSTD_CStreamSizes -_ZSTD_DStreamSizes = _zstd._ZSTD_DStreamSizes -_train_dict = _zstd._train_dict -_finalize_dict = _zstd._finalize_dict - - COMPRESSION_LEVEL_DEFAULT = _zstd._compressionLevel_values[0] """The default compression level for Zstandard, currently '3'.""" @@ -91,8 +85,7 @@ def train_dict(samples, dict_size): chunk_sizes = tuple(_nbytes(sample) for sample in samples) if not chunks: raise ValueError("samples contained no data; can't train dictionary.") - dict_content = _train_dict(chunks, chunk_sizes, dict_size) - + dict_content = _zstd._train_dict(chunks, chunk_sizes, dict_size) return ZstdDict(dict_content) @@ -128,9 +121,9 @@ def finalize_dict(zstd_dict, samples, dict_size, level): if not chunks: raise ValueError("The samples are empty content, can't finalize" "dictionary.") - dict_content = _finalize_dict(zstd_dict.dict_content, - chunks, chunk_sizes, - dict_size, level) + dict_content = _zstd._finalize_dict(zstd_dict.dict_content, + chunks, chunk_sizes, + dict_size, level) return ZstdDict(dict_content) def compress(data, level=None, options=None, zstd_dict=None): From 10f0cffbc04eea29fb7a88932a257cb2304d0971 Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+aa-turner@users.noreply.github.com> Date: 
Tue, 6 May 2025 00:48:41 +0100 Subject: [PATCH 52/55] Misc changes (positional-only, style, error messages) --- Lib/compression/zstd/__init__.py | 4 +-- Lib/compression/zstd/_zstdfile.py | 49 ++++++++++--------------------- 2 files changed, 17 insertions(+), 36 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 9296c141a24597..4f734eb07b00e3 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -89,7 +89,7 @@ def train_dict(samples, dict_size): return ZstdDict(dict_content) -def finalize_dict(zstd_dict, samples, dict_size, level): +def finalize_dict(zstd_dict, /, samples, dict_size, level): """Return a ZstdDict representing a finalized Zstandard dictionary. Given a custom content as a basis for dictionary, and a set of samples, @@ -119,7 +119,7 @@ def finalize_dict(zstd_dict, samples, dict_size, level): chunks = b''.join(samples) chunk_sizes = tuple(_nbytes(sample) for sample in samples) if not chunks: - raise ValueError("The samples are empty content, can't finalize" + raise ValueError("The samples are empty content, can't finalize the" "dictionary.") dict_content = _zstd._finalize_dict(zstd_dict.dict_content, chunks, chunk_sizes, diff --git a/Lib/compression/zstd/_zstdfile.py b/Lib/compression/zstd/_zstdfile.py index 775e063d1f3344..fbc9e02a733626 100644 --- a/Lib/compression/zstd/_zstdfile.py +++ b/Lib/compression/zstd/_zstdfile.py @@ -13,7 +13,7 @@ _MODE_WRITE = 2 -def _nbytes(dat): +def _nbytes(dat, /): if isinstance(dat, (bytes, bytearray)): return len(dat) with memoryview(dat) as mv: @@ -33,16 +33,8 @@ class ZstdFile(_streams.BaseStream): FLUSH_BLOCK = ZstdCompressor.FLUSH_BLOCK FLUSH_FRAME = ZstdCompressor.FLUSH_FRAME - def __init__( - self, - file, - /, - mode="r", - *, - level=None, - options=None, - zstd_dict=None, - ): + def __init__(self, file, /, mode="r", *, + level=None, options=None, zstd_dict=None): """Open a Zstandard compressed file in binary mode. *file* can be either an file-like object, or a file name to open. @@ -79,9 +71,8 @@ def __init__( if level is not None and not isinstance(level, int): raise TypeError("level must be int or None") self._mode = _MODE_WRITE - self._compressor = ZstdCompressor( - level=level, options=options, zstd_dict=zstd_dict - ) + self._compressor = ZstdCompressor(level=level, options=options, + zstd_dict=zstd_dict) self._pos = 0 else: raise ValueError(f"Invalid mode: {mode!r}") @@ -131,7 +122,7 @@ def close(self): self._fp = None self._close_fp = False - def write(self, data): + def write(self, data, /): """Write a bytes-like object *data* to the file. Returns the number of uncompressed bytes written, which is @@ -305,19 +296,8 @@ def writable(self): return self._mode == _MODE_WRITE -# Copied from lzma module -def open( - file, - /, - mode="rb", - *, - level=None, - options=None, - zstd_dict=None, - encoding=None, - errors=None, - newline=None, -): +def open(file, /, mode="rb", *, level=None, options=None, zstd_dict=None, + encoding=None, errors=None, newline=None): """Open a Zstandard compressed file in binary or text mode. file can be either a file name (given as a str, bytes, or PathLike object), @@ -346,7 +326,10 @@ def open( behavior, and line ending(s). 
""" - if "t" in mode: + text_mode = "t" in mode + mode = mode.replace("t", "") + + if text_mode: if "b" in mode: raise ValueError(f"Invalid mode: {mode!r}") else: @@ -357,12 +340,10 @@ def open( if newline is not None: raise ValueError("Argument 'newline' not supported in binary mode") - zstd_mode = mode.replace("t", "") - binary_file = ZstdFile( - file, zstd_mode, level=level, options=options, zstd_dict=zstd_dict - ) + binary_file = ZstdFile(file, mode, level=level, options=options, + zstd_dict=zstd_dict) - if "t" in mode: + if text_mode: return io.TextIOWrapper(binary_file, encoding, errors, newline) else: return binary_file From 7f8c350fbd680a831d449c694e0d5056edbed34d Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+aa-turner@users.noreply.github.com> Date: Tue, 6 May 2025 00:48:46 +0100 Subject: [PATCH 53/55] whitespace --- Modules/_zstd/_zstdmodule.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c index 81e950b6c19d34..4d046859a1540e 100644 --- a/Modules/_zstd/_zstdmodule.c +++ b/Modules/_zstd/_zstdmodule.c @@ -537,7 +537,7 @@ _zstd__set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type, Py_XDECREF(mod_state->CParameter_type); Py_INCREF(c_parameter_type); - mod_state->CParameter_type = (PyTypeObject*) c_parameter_type; + mod_state->CParameter_type = (PyTypeObject*)c_parameter_type; Py_XDECREF(mod_state->DParameter_type); Py_INCREF(d_parameter_type); From bf4b07daade1a55da151db571fc581d80899e8db Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+aa-turner@users.noreply.github.com> Date: Tue, 6 May 2025 00:51:35 +0100 Subject: [PATCH 54/55] Remove _set_parameter_types Python, in general, works on protocols rather than nominal or specific types. I think this check is overly restrictive, with only minor benefit. Users wanting validation can use static type checkers. --- Lib/compression/zstd/__init__.py | 4 -- Modules/_zstd/_zstdmodule.c | 46 ----------------- Modules/_zstd/_zstdmodule.h | 3 -- Modules/_zstd/clinic/_zstdmodule.c.h | 76 +--------------------------- Modules/_zstd/compressor.c | 8 --- Modules/_zstd/decompressor.c | 8 --- 6 files changed, 1 insertion(+), 144 deletions(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 4f734eb07b00e3..0dd83b846a2b03 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -228,7 +228,3 @@ class Strategy(enum.IntEnum): btopt = _zstd._ZSTD_btopt btultra = _zstd._ZSTD_btultra btultra2 = _zstd._ZSTD_btultra2 - - -# Check validity of the CompressionParameter & DecompressionParameter types -_zstd._set_parameter_types(CompressionParameter, DecompressionParameter) diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c index 4d046859a1540e..f881cefa80a2ed 100644 --- a/Modules/_zstd/_zstdmodule.c +++ b/Modules/_zstd/_zstdmodule.c @@ -510,49 +510,12 @@ _zstd__get_frame_info_impl(PyObject *module, Py_buffer *frame_buffer) return Py_BuildValue("KI", decompressed_size, dict_id); } -/*[clinic input] -_zstd._set_parameter_types - - c_parameter_type: object(subclass_of='&PyType_Type') - CompressionParameter IntEnum type object - d_parameter_type: object(subclass_of='&PyType_Type') - DecompressionParameter IntEnum type object - -Internal function, set CompressionParameter/DecompressionParameter types for validity check. 
-[clinic start generated code]*/ - -static PyObject * -_zstd__set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type, - PyObject *d_parameter_type) -/*[clinic end generated code: output=a13d4890ccbd2873 input=4535545d903853d3]*/ -{ - _zstd_state* const mod_state = get_zstd_state(module); - - if (!PyType_Check(c_parameter_type) || !PyType_Check(d_parameter_type)) { - PyErr_SetString(PyExc_ValueError, - "The two arguments should be CompressionParameter and " - "DecompressionParameter types."); - return NULL; - } - - Py_XDECREF(mod_state->CParameter_type); - Py_INCREF(c_parameter_type); - mod_state->CParameter_type = (PyTypeObject*)c_parameter_type; - - Py_XDECREF(mod_state->DParameter_type); - Py_INCREF(d_parameter_type); - mod_state->DParameter_type = (PyTypeObject*)d_parameter_type; - - Py_RETURN_NONE; -} - static PyMethodDef _zstd_methods[] = { _ZSTD__TRAIN_DICT_METHODDEF _ZSTD__FINALIZE_DICT_METHODDEF _ZSTD__GET_PARAM_BOUNDS_METHODDEF _ZSTD_GET_FRAME_SIZE_METHODDEF _ZSTD__GET_FRAME_INFO_METHODDEF - _ZSTD__SET_PARAMETER_TYPES_METHODDEF {0} }; @@ -766,9 +729,6 @@ static int _zstd_exec(PyObject *module) { ADD_STR_TO_STATE_MACRO(write); ADD_STR_TO_STATE_MACRO(flush); - mod_state->CParameter_type = NULL; - mod_state->DParameter_type = NULL; - /* Add variables to module */ if (add_vars_to_module(module) < 0) { return -1; @@ -852,9 +812,6 @@ _zstd_traverse(PyObject *module, visitproc visit, void *arg) Py_VISIT(mod_state->ZstdDecompressor_type); Py_VISIT(mod_state->ZstdError); - - Py_VISIT(mod_state->CParameter_type); - Py_VISIT(mod_state->DParameter_type); return 0; } @@ -876,9 +833,6 @@ _zstd_clear(PyObject *module) Py_CLEAR(mod_state->ZstdDecompressor_type); Py_CLEAR(mod_state->ZstdError); - - Py_CLEAR(mod_state->CParameter_type); - Py_CLEAR(mod_state->DParameter_type); return 0; } diff --git a/Modules/_zstd/_zstdmodule.h b/Modules/_zstd/_zstdmodule.h index d50f1489e6f574..d7c9e69eb602f7 100644 --- a/Modules/_zstd/_zstdmodule.h +++ b/Modules/_zstd/_zstdmodule.h @@ -54,9 +54,6 @@ struct _zstd_state { PyTypeObject *ZstdCompressor_type; PyTypeObject *ZstdDecompressor_type; PyObject *ZstdError; - - PyTypeObject *CParameter_type; - PyTypeObject *DParameter_type; }; typedef struct { diff --git a/Modules/_zstd/clinic/_zstdmodule.c.h b/Modules/_zstd/clinic/_zstdmodule.c.h index 2f8225389b7aea..9e04c39b955754 100644 --- a/Modules/_zstd/clinic/_zstdmodule.c.h +++ b/Modules/_zstd/clinic/_zstdmodule.c.h @@ -355,78 +355,4 @@ _zstd__get_frame_info(PyObject *module, PyObject *const *args, Py_ssize_t nargs, return return_value; } - -PyDoc_STRVAR(_zstd__set_parameter_types__doc__, -"_set_parameter_types($module, /, c_parameter_type, d_parameter_type)\n" -"--\n" -"\n" -"Internal function, set CompressionParameter/DecompressionParameter types for validity check.\n" -"\n" -" c_parameter_type\n" -" CompressionParameter IntEnum type object\n" -" d_parameter_type\n" -" DecompressionParameter IntEnum type object"); - -#define _ZSTD__SET_PARAMETER_TYPES_METHODDEF \ - {"_set_parameter_types", _PyCFunction_CAST(_zstd__set_parameter_types), METH_FASTCALL|METH_KEYWORDS, _zstd__set_parameter_types__doc__}, - -static PyObject * -_zstd__set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type, - PyObject *d_parameter_type); - -static PyObject * -_zstd__set_parameter_types(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) -{ - PyObject *return_value = NULL; - #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) - - #define NUM_KEYWORDS 2 - static struct { - 
PyGC_Head _this_is_not_used; - PyObject_VAR_HEAD - Py_hash_t ob_hash; - PyObject *ob_item[NUM_KEYWORDS]; - } _kwtuple = { - .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS) - .ob_hash = -1, - .ob_item = { &_Py_ID(c_parameter_type), &_Py_ID(d_parameter_type), }, - }; - #undef NUM_KEYWORDS - #define KWTUPLE (&_kwtuple.ob_base.ob_base) - - #else // !Py_BUILD_CORE - # define KWTUPLE NULL - #endif // !Py_BUILD_CORE - - static const char * const _keywords[] = {"c_parameter_type", "d_parameter_type", NULL}; - static _PyArg_Parser _parser = { - .keywords = _keywords, - .fname = "_set_parameter_types", - .kwtuple = KWTUPLE, - }; - #undef KWTUPLE - PyObject *argsbuf[2]; - PyObject *c_parameter_type; - PyObject *d_parameter_type; - - args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, - /*minpos*/ 2, /*maxpos*/ 2, /*minkw*/ 0, /*varpos*/ 0, argsbuf); - if (!args) { - goto exit; - } - if (!PyObject_TypeCheck(args[0], &PyType_Type)) { - _PyArg_BadArgument("_set_parameter_types", "argument 'c_parameter_type'", (&PyType_Type)->tp_name, args[0]); - goto exit; - } - c_parameter_type = args[0]; - if (!PyObject_TypeCheck(args[1], &PyType_Type)) { - _PyArg_BadArgument("_set_parameter_types", "argument 'd_parameter_type'", (&PyType_Type)->tp_name, args[1]); - goto exit; - } - d_parameter_type = args[1]; - return_value = _zstd__set_parameter_types_impl(module, c_parameter_type, d_parameter_type); - -exit: - return return_value; -} -/*[clinic end generated code: output=189c462236a7096c input=a9049054013a1b77]*/ +/*[clinic end generated code: output=d17e7f5a90a9daca input=a9049054013a1b77]*/ diff --git a/Modules/_zstd/compressor.c b/Modules/_zstd/compressor.c index b735981e7476d5..5b37d660ce7a61 100644 --- a/Modules/_zstd/compressor.c +++ b/Modules/_zstd/compressor.c @@ -67,14 +67,6 @@ _PyZstd_set_c_parameters(ZstdCompressor *self, PyObject *level_or_options, Py_ssize_t pos = 0; while (PyDict_Next(level_or_options, &pos, &key, &value)) { - /* Check key type */ - if (Py_TYPE(key) == mod_state->DParameter_type) { - PyErr_SetString(PyExc_TypeError, - "Key of compression option dict should " - "NOT be DecompressionParameter."); - return -1; - } - int key_v = PyLong_AsInt(key); if (key_v == -1 && PyErr_Occurred()) { PyErr_SetString(PyExc_ValueError, diff --git a/Modules/_zstd/decompressor.c b/Modules/_zstd/decompressor.c index a4be180c0088fc..87bb415ba29f5c 100644 --- a/Modules/_zstd/decompressor.c +++ b/Modules/_zstd/decompressor.c @@ -80,14 +80,6 @@ _PyZstd_set_d_parameters(ZstdDecompressor *self, PyObject *options) pos = 0; while (PyDict_Next(options, &pos, &key, &value)) { - /* Check key type */ - if (Py_TYPE(key) == mod_state->CParameter_type) { - PyErr_SetString(PyExc_TypeError, - "Key of decompression options dict should " - "NOT be CompressionParameter."); - return -1; - } - /* Both key & value should be 32-bit signed int */ int key_v = PyLong_AsInt(key); if (key_v == -1 && PyErr_Occurred()) { From eaf46a8a561817f6eeec9a673c262cb6cf88a9fe Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Mon, 5 May 2025 17:07:49 -0700 Subject: [PATCH 55/55] Revert "Remove _set_parameter_types" This reverts commit bf4b07daade1a55da151db571fc581d80899e8db. Checking the type of parameters is important to avoid confusing error messages. 
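
For illustration only (not part of this patch), a minimal sketch of the
confusion the check guards against, assuming the compression.zstd module as
built by this series:

    from compression.zstd import ZstdCompressor, DecompressionParameter

    # With the type check in place, mixing parameter kinds fails loudly:
    #   TypeError: Key of compression option dict should NOT be
    #   DecompressionParameter.
    # Without it, the enum's integer value would be passed straight through
    # and could be treated as an unrelated compression parameter.
    try:
        ZstdCompressor(options={DecompressionParameter.window_log_max: 25})
    except TypeError as exc:
        print(exc)
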
--- Lib/compression/zstd/__init__.py | 4 ++ Modules/_zstd/_zstdmodule.c | 46 +++++++++++++++++ Modules/_zstd/_zstdmodule.h | 3 ++ Modules/_zstd/clinic/_zstdmodule.c.h | 76 +++++++++++++++++++++++++++- Modules/_zstd/compressor.c | 8 +++ Modules/_zstd/decompressor.c | 8 +++ 6 files changed, 144 insertions(+), 1 deletion(-) diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py index 0dd83b846a2b03..4f734eb07b00e3 100644 --- a/Lib/compression/zstd/__init__.py +++ b/Lib/compression/zstd/__init__.py @@ -228,3 +228,7 @@ class Strategy(enum.IntEnum): btopt = _zstd._ZSTD_btopt btultra = _zstd._ZSTD_btultra btultra2 = _zstd._ZSTD_btultra2 + + +# Check validity of the CompressionParameter & DecompressionParameter types +_zstd._set_parameter_types(CompressionParameter, DecompressionParameter) diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c index f881cefa80a2ed..4d046859a1540e 100644 --- a/Modules/_zstd/_zstdmodule.c +++ b/Modules/_zstd/_zstdmodule.c @@ -510,12 +510,49 @@ _zstd__get_frame_info_impl(PyObject *module, Py_buffer *frame_buffer) return Py_BuildValue("KI", decompressed_size, dict_id); } +/*[clinic input] +_zstd._set_parameter_types + + c_parameter_type: object(subclass_of='&PyType_Type') + CompressionParameter IntEnum type object + d_parameter_type: object(subclass_of='&PyType_Type') + DecompressionParameter IntEnum type object + +Internal function, set CompressionParameter/DecompressionParameter types for validity check. +[clinic start generated code]*/ + +static PyObject * +_zstd__set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type, + PyObject *d_parameter_type) +/*[clinic end generated code: output=a13d4890ccbd2873 input=4535545d903853d3]*/ +{ + _zstd_state* const mod_state = get_zstd_state(module); + + if (!PyType_Check(c_parameter_type) || !PyType_Check(d_parameter_type)) { + PyErr_SetString(PyExc_ValueError, + "The two arguments should be CompressionParameter and " + "DecompressionParameter types."); + return NULL; + } + + Py_XDECREF(mod_state->CParameter_type); + Py_INCREF(c_parameter_type); + mod_state->CParameter_type = (PyTypeObject*)c_parameter_type; + + Py_XDECREF(mod_state->DParameter_type); + Py_INCREF(d_parameter_type); + mod_state->DParameter_type = (PyTypeObject*)d_parameter_type; + + Py_RETURN_NONE; +} + static PyMethodDef _zstd_methods[] = { _ZSTD__TRAIN_DICT_METHODDEF _ZSTD__FINALIZE_DICT_METHODDEF _ZSTD__GET_PARAM_BOUNDS_METHODDEF _ZSTD_GET_FRAME_SIZE_METHODDEF _ZSTD__GET_FRAME_INFO_METHODDEF + _ZSTD__SET_PARAMETER_TYPES_METHODDEF {0} }; @@ -729,6 +766,9 @@ static int _zstd_exec(PyObject *module) { ADD_STR_TO_STATE_MACRO(write); ADD_STR_TO_STATE_MACRO(flush); + mod_state->CParameter_type = NULL; + mod_state->DParameter_type = NULL; + /* Add variables to module */ if (add_vars_to_module(module) < 0) { return -1; @@ -812,6 +852,9 @@ _zstd_traverse(PyObject *module, visitproc visit, void *arg) Py_VISIT(mod_state->ZstdDecompressor_type); Py_VISIT(mod_state->ZstdError); + + Py_VISIT(mod_state->CParameter_type); + Py_VISIT(mod_state->DParameter_type); return 0; } @@ -833,6 +876,9 @@ _zstd_clear(PyObject *module) Py_CLEAR(mod_state->ZstdDecompressor_type); Py_CLEAR(mod_state->ZstdError); + + Py_CLEAR(mod_state->CParameter_type); + Py_CLEAR(mod_state->DParameter_type); return 0; } diff --git a/Modules/_zstd/_zstdmodule.h b/Modules/_zstd/_zstdmodule.h index d7c9e69eb602f7..d50f1489e6f574 100644 --- a/Modules/_zstd/_zstdmodule.h +++ b/Modules/_zstd/_zstdmodule.h @@ -54,6 +54,9 @@ struct _zstd_state { PyTypeObject 
*ZstdCompressor_type; PyTypeObject *ZstdDecompressor_type; PyObject *ZstdError; + + PyTypeObject *CParameter_type; + PyTypeObject *DParameter_type; }; typedef struct { diff --git a/Modules/_zstd/clinic/_zstdmodule.c.h b/Modules/_zstd/clinic/_zstdmodule.c.h index 9e04c39b955754..2f8225389b7aea 100644 --- a/Modules/_zstd/clinic/_zstdmodule.c.h +++ b/Modules/_zstd/clinic/_zstdmodule.c.h @@ -355,4 +355,78 @@ _zstd__get_frame_info(PyObject *module, PyObject *const *args, Py_ssize_t nargs, return return_value; } -/*[clinic end generated code: output=d17e7f5a90a9daca input=a9049054013a1b77]*/ + +PyDoc_STRVAR(_zstd__set_parameter_types__doc__, +"_set_parameter_types($module, /, c_parameter_type, d_parameter_type)\n" +"--\n" +"\n" +"Internal function, set CompressionParameter/DecompressionParameter types for validity check.\n" +"\n" +" c_parameter_type\n" +" CompressionParameter IntEnum type object\n" +" d_parameter_type\n" +" DecompressionParameter IntEnum type object"); + +#define _ZSTD__SET_PARAMETER_TYPES_METHODDEF \ + {"_set_parameter_types", _PyCFunction_CAST(_zstd__set_parameter_types), METH_FASTCALL|METH_KEYWORDS, _zstd__set_parameter_types__doc__}, + +static PyObject * +_zstd__set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type, + PyObject *d_parameter_type); + +static PyObject * +_zstd__set_parameter_types(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) +{ + PyObject *return_value = NULL; + #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) + + #define NUM_KEYWORDS 2 + static struct { + PyGC_Head _this_is_not_used; + PyObject_VAR_HEAD + Py_hash_t ob_hash; + PyObject *ob_item[NUM_KEYWORDS]; + } _kwtuple = { + .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS) + .ob_hash = -1, + .ob_item = { &_Py_ID(c_parameter_type), &_Py_ID(d_parameter_type), }, + }; + #undef NUM_KEYWORDS + #define KWTUPLE (&_kwtuple.ob_base.ob_base) + + #else // !Py_BUILD_CORE + # define KWTUPLE NULL + #endif // !Py_BUILD_CORE + + static const char * const _keywords[] = {"c_parameter_type", "d_parameter_type", NULL}; + static _PyArg_Parser _parser = { + .keywords = _keywords, + .fname = "_set_parameter_types", + .kwtuple = KWTUPLE, + }; + #undef KWTUPLE + PyObject *argsbuf[2]; + PyObject *c_parameter_type; + PyObject *d_parameter_type; + + args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, + /*minpos*/ 2, /*maxpos*/ 2, /*minkw*/ 0, /*varpos*/ 0, argsbuf); + if (!args) { + goto exit; + } + if (!PyObject_TypeCheck(args[0], &PyType_Type)) { + _PyArg_BadArgument("_set_parameter_types", "argument 'c_parameter_type'", (&PyType_Type)->tp_name, args[0]); + goto exit; + } + c_parameter_type = args[0]; + if (!PyObject_TypeCheck(args[1], &PyType_Type)) { + _PyArg_BadArgument("_set_parameter_types", "argument 'd_parameter_type'", (&PyType_Type)->tp_name, args[1]); + goto exit; + } + d_parameter_type = args[1]; + return_value = _zstd__set_parameter_types_impl(module, c_parameter_type, d_parameter_type); + +exit: + return return_value; +} +/*[clinic end generated code: output=189c462236a7096c input=a9049054013a1b77]*/ diff --git a/Modules/_zstd/compressor.c b/Modules/_zstd/compressor.c index 5b37d660ce7a61..b735981e7476d5 100644 --- a/Modules/_zstd/compressor.c +++ b/Modules/_zstd/compressor.c @@ -67,6 +67,14 @@ _PyZstd_set_c_parameters(ZstdCompressor *self, PyObject *level_or_options, Py_ssize_t pos = 0; while (PyDict_Next(level_or_options, &pos, &key, &value)) { + /* Check key type */ + if (Py_TYPE(key) == mod_state->DParameter_type) { + 
PyErr_SetString(PyExc_TypeError, + "Key of compression option dict should " + "NOT be DecompressionParameter."); + return -1; + } + int key_v = PyLong_AsInt(key); if (key_v == -1 && PyErr_Occurred()) { PyErr_SetString(PyExc_ValueError, diff --git a/Modules/_zstd/decompressor.c b/Modules/_zstd/decompressor.c index 87bb415ba29f5c..a4be180c0088fc 100644 --- a/Modules/_zstd/decompressor.c +++ b/Modules/_zstd/decompressor.c @@ -80,6 +80,14 @@ _PyZstd_set_d_parameters(ZstdDecompressor *self, PyObject *options) pos = 0; while (PyDict_Next(options, &pos, &key, &value)) { + /* Check key type */ + if (Py_TYPE(key) == mod_state->CParameter_type) { + PyErr_SetString(PyExc_TypeError, + "Key of decompression options dict should " + "NOT be CompressionParameter."); + return -1; + } + /* Both key & value should be 32-bit signed int */ int key_v = PyLong_AsInt(key); if (key_v == -1 && PyErr_Occurred()) {