diff --git a/src/check_jsonschema/cli/main_command.py b/src/check_jsonschema/cli/main_command.py index 8f7763994..b384af478 100644 --- a/src/check_jsonschema/cli/main_command.py +++ b/src/check_jsonschema/cli/main_command.py @@ -20,7 +20,7 @@ SchemaLoaderBase, ) from ..transforms import TRANSFORM_LIBRARY -from .param_types import CommaDelimitedList, ValidatorClassName +from .param_types import CommaDelimitedList, LazyBinaryReadFile, ValidatorClassName from .parse_result import ParseResult, SchemaLoadingMode BUILTIN_SCHEMA_NAMES = [f"vendor.{k}" for k in SCHEMA_CATALOG.keys()] + [ @@ -220,7 +220,9 @@ def pretty_helptext_list(values: list[str] | tuple[str, ...]) -> str: help="Reduce output verbosity", count=True, ) -@click.argument("instancefiles", required=True, nargs=-1, type=click.File("rb")) +@click.argument( + "instancefiles", required=True, nargs=-1, type=LazyBinaryReadFile("rb", lazy=True) +) def main( *, schemafile: str | None, diff --git a/src/check_jsonschema/cli/param_types.py b/src/check_jsonschema/cli/param_types.py index 7c80f8669..234344c95 100644 --- a/src/check_jsonschema/cli/param_types.py +++ b/src/check_jsonschema/cli/param_types.py @@ -1,11 +1,14 @@ from __future__ import annotations import importlib +import os import re +import stat import typing as t import click import jsonschema +from click._compat import open_stream class CommaDelimitedList(click.ParamType): @@ -104,3 +107,50 @@ def convert( self.fail(f"'{classname}' in '{pkg}' is not a class", param, ctx) return t.cast(t.Type[jsonschema.protocols.Validator], result) + + +class CustomLazyFile(click.utils.LazyFile): + def __init__( + self, + filename: str | os.PathLike[str], + mode: str = "r", + encoding: str | None = None, + errors: str | None = "strict", + atomic: bool = False, + ): + self.name: str = os.fspath(filename) + self.mode = mode + self.encoding = encoding + self.errors = errors + self.atomic = atomic + self._f: t.IO[t.Any] | None + self.should_close: bool + + if self.name == "-": 
+ self._f, self.should_close = open_stream(filename, mode, encoding, errors) + else: + if "r" in mode and not stat.S_ISFIFO(os.stat(filename).st_mode): + # Open and close the file in case we're opening it for + # reading so that we can catch at least some errors in + # some cases early. + open(filename, mode).close() + self._f = None + self.should_close = True + + +class LazyBinaryReadFile(click.File): + def convert( + self, + value: str | os.PathLike[str] | t.IO[t.Any], + param: click.Parameter | None, + ctx: click.Context | None, + ) -> t.IO[bytes]: + if hasattr(value, "read") or hasattr(value, "write"): + return t.cast(t.IO[bytes], value) + + value_: str | os.PathLike[str] = t.cast("str | os.PathLike[str]", value) + + lf = CustomLazyFile(value_, mode="rb") + if ctx is not None: + ctx.call_on_close(lf.close_intelligently) + return t.cast(t.IO[bytes], lf) diff --git a/src/check_jsonschema/instance_loader.py b/src/check_jsonschema/instance_loader.py index d025ab8ce..5c2acca6f 100644 --- a/src/check_jsonschema/instance_loader.py +++ b/src/check_jsonschema/instance_loader.py @@ -3,6 +3,8 @@ import io import typing as t +from check_jsonschema.cli.param_types import CustomLazyFile + from .parsers import ParseError, ParserSet from .transforms import Transform @@ -10,7 +12,7 @@ class InstanceLoader: def __init__( self, - files: t.Sequence[t.BinaryIO], + files: t.Sequence[t.BinaryIO | CustomLazyFile], default_filetype: str = "json", data_transform: Transform | None = None, ) -> None: @@ -35,12 +37,21 @@ def iter_files(self) -> t.Iterator[tuple[str, ParseError | t.Any]]: name = "" else: raise ValueError(f"File {file} has no name attribute") + try: - data: t.Any = self._parsers.parse_data_with_path( - file, name, self._default_filetype - ) - except ParseError as err: - data = err - else: - data = self._data_transform(data) + if isinstance(file, CustomLazyFile): + stream: t.BinaryIO = t.cast(t.BinaryIO, file.open()) + else: + stream = file + + try: + data: t.Any = 
self._parsers.parse_data_with_path( + stream, name, self._default_filetype + ) + except ParseError as err: + data = err + else: + data = self._data_transform(data) + finally: + file.close() yield (name, data) diff --git a/src/check_jsonschema/schema_loader/readers.py b/src/check_jsonschema/schema_loader/readers.py index 244fa4119..907ce6936 100644 --- a/src/check_jsonschema/schema_loader/readers.py +++ b/src/check_jsonschema/schema_loader/readers.py @@ -15,6 +15,13 @@ yaml = ruamel.yaml.YAML(typ="safe") +class _UnsetType: + pass + + +_UNSET = _UnsetType() + + def _run_load_callback(schema_location: str, callback: t.Callable) -> dict: try: schema = callback() @@ -31,6 +38,7 @@ def __init__(self, filename: str) -> None: self.path = filename2path(filename) self.filename = str(self.path) self.parsers = ParserSet() + self._parsed_schema: dict | _UnsetType = _UNSET def get_retrieval_uri(self) -> str | None: return self.path.as_uri() @@ -39,21 +47,26 @@ def _read_impl(self) -> t.Any: return self.parsers.parse_file(self.path, default_filetype="json") def read_schema(self) -> dict: - return _run_load_callback(self.filename, self._read_impl) + if self._parsed_schema is _UNSET: + self._parsed_schema = _run_load_callback(self.filename, self._read_impl) + return t.cast(dict, self._parsed_schema) class StdinSchemaReader: def __init__(self) -> None: self.parsers = ParserSet() + self._parsed_schema: dict | _UnsetType = _UNSET def get_retrieval_uri(self) -> str | None: return None def read_schema(self) -> dict: - try: - return json.load(sys.stdin) - except ValueError as e: - raise ParseError("Failed to parse JSON from stdin") from e + if self._parsed_schema is _UNSET: + try: + self._parsed_schema = json.load(sys.stdin) + except ValueError as e: + raise ParseError("Failed to parse JSON from stdin") from e + return t.cast(dict, self._parsed_schema) class HttpSchemaReader: @@ -71,14 +84,12 @@ def __init__( disable_cache=disable_cache, validation_callback=self._parse, ) - 
self._parsed_schema: t.Any | None = None + self._parsed_schema: dict | _UnsetType = _UNSET def _parse(self, schema_bytes: bytes) -> t.Any: - if self._parsed_schema is None: - self._parsed_schema = self.parsers.parse_data_with_path( - io.BytesIO(schema_bytes), self.url, default_filetype="json" - ) - return self._parsed_schema + return self.parsers.parse_data_with_path( + io.BytesIO(schema_bytes), self.url, default_filetype="json" + ) def get_retrieval_uri(self) -> str | None: return self.url @@ -88,4 +99,6 @@ def _read_impl(self) -> t.Any: return self._parse(fp.read()) def read_schema(self) -> dict: - return _run_load_callback(self.url, self._read_impl) + if self._parsed_schema is _UNSET: + self._parsed_schema = _run_load_callback(self.url, self._read_impl) + return t.cast(dict, self._parsed_schema) diff --git a/tests/acceptance/test_special_filetypes.py b/tests/acceptance/test_special_filetypes.py index c148913b7..70ca1cdcf 100644 --- a/tests/acceptance/test_special_filetypes.py +++ b/tests/acceptance/test_special_filetypes.py @@ -1,7 +1,7 @@ +import multiprocessing import os import platform import sys -import threading import pytest import responses @@ -33,6 +33,16 @@ def test_schema_and_instance_in_memfds(run_line_simple): os.close(instancefd) +# helper (in global scope) for multiprocessing "spawn" to be able to use to launch +# background writers +def _fifo_write(path, data): + fd = os.open(path, os.O_WRONLY) + try: + os.write(fd, data) + finally: + os.close(fd) + + @pytest.mark.skipif(os.name != "posix", reason="test requires mkfifo") @pytest.mark.parametrize("check_succeeds", (True, False)) def test_schema_and_instance_in_fifos(tmp_path, run_line, check_succeeds): @@ -45,25 +55,17 @@ def test_schema_and_instance_in_fifos(tmp_path, run_line, check_succeeds): os.mkfifo(schema_path) os.mkfifo(instance_path) - # execute FIFO writes as blocking writes in background threads - # nonblocking writes fail file existence if there's no reader, so using a FIFO - # requires 
some level of concurrency - def fifo_write(path, data): - fd = os.open(path, os.O_WRONLY) - try: - os.write(fd, data) - finally: - os.close(fd) - - schema_thread = threading.Thread( - target=fifo_write, args=[schema_path, b'{"type": "integer"}'] + spawn_ctx = multiprocessing.get_context("spawn") + + schema_proc = spawn_ctx.Process( + target=_fifo_write, args=(schema_path, b'{"type": "integer"}') ) - schema_thread.start() + schema_proc.start() instance_data = b"42" if check_succeeds else b'"foo"' - instance_thread = threading.Thread( - target=fifo_write, args=[instance_path, instance_data] + instance_proc = spawn_ctx.Process( + target=_fifo_write, args=(instance_path, instance_data) ) - instance_thread.start() + instance_proc.start() try: result = run_line( @@ -74,8 +76,8 @@ def fifo_write(path, data): else: assert result.exit_code == 1 finally: - schema_thread.join(timeout=0.1) - instance_thread.join(timeout=0.1) + schema_proc.terminate() + instance_proc.terminate() @pytest.mark.parametrize("check_passes", (True, False)) diff --git a/tests/unit/test_cli_parse.py b/tests/unit/test_cli_parse.py index f5c1e56b9..fd376ef32 100644 --- a/tests/unit/test_cli_parse.py +++ b/tests/unit/test_cli_parse.py @@ -1,6 +1,5 @@ from __future__ import annotations -import io from unittest import mock import click @@ -86,7 +85,7 @@ def test_schemafile_and_instancefile(runner, mock_parse_result, in_tmp_dir, tmp_ assert mock_parse_result.schema_path == "schema.json" assert isinstance(mock_parse_result.instancefiles, tuple) for f in mock_parse_result.instancefiles: - assert isinstance(f, (io.BytesIO, io.BufferedReader)) + assert isinstance(f, click.utils.LazyFile) assert tuple(f.name for f in mock_parse_result.instancefiles) == ("foo.json",) diff --git a/tests/unit/test_lazy_file_handling.py b/tests/unit/test_lazy_file_handling.py new file mode 100644 index 000000000..dd69eac60 --- /dev/null +++ b/tests/unit/test_lazy_file_handling.py @@ -0,0 +1,46 @@ +import os +import platform + +import 
pytest
+from click.testing import CliRunner
+
+from check_jsonschema.cli.main_command import build_checker
+from check_jsonschema.cli.main_command import main as cli_main
+
+
+@pytest.fixture
+def runner() -> CliRunner:
+    return CliRunner(mix_stderr=False)
+
+
+@pytest.mark.skipif(
+    platform.system() != "Linux", reason="test requires /proc/self/ mechanism"
+)
+def test_open_file_usage_never_exceeds_1000(runner, monkeypatch, tmp_path):
+    schema_path = tmp_path / "schema.json"
+    schema_path.write_text("{}")
+
+    args = [
+        "--schemafile",
+        str(schema_path),
+    ]
+
+    for i in range(2000):
+        instance_path = tmp_path / f"file{i}.json"
+        instance_path.write_text("{}")
+        args.append(str(instance_path))
+
+    checker = None
+
+    def fake_execute(argv):
+        nonlocal checker
+        checker = build_checker(argv)
+
+    monkeypatch.setattr("check_jsonschema.cli.main_command.execute", fake_execute)
+    res = runner.invoke(cli_main, args)
+    assert res.exit_code == 0, res.stderr
+
+    assert checker is not None
+    assert len(os.listdir("/proc/self/fd")) < 2000
+    for _fname, _data in checker._instance_loader.iter_files():
+        assert len(os.listdir("/proc/self/fd")) < 2000
diff --git a/tox.ini b/tox.ini
index 1a475ca0f..709c6a90d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -46,12 +46,10 @@ commands = coverage report --skip-covered
 [testenv:mypy]
 description = "check type annotations with mypy"
-# temporarily pin back click until either click 8.1.5 releases or mypy fixes the issue
-# with referential integrity of type aliases
 deps =
     mypy
     types-jsonschema
     types-requests
-    click==8.1.3
+    click
 commands = mypy src/ {posargs}
 
 [testenv:pyright]