diff --git a/openapi_python_client/config.py b/openapi_python_client/config.py index 21cb4d182..a587d5a30 100644 --- a/openapi_python_client/config.py +++ b/openapi_python_client/config.py @@ -2,7 +2,7 @@ import mimetypes from enum import Enum from pathlib import Path -from typing import Optional, Union +from typing import Literal, Optional, Union from attr import define from pydantic import BaseModel @@ -28,6 +28,9 @@ class MetaType(str, Enum): PDM = "pdm" UV = "uv" +class JSONDecoder(str, Enum): + UJSON = "ujson" + ORJSON = "orjson" class ConfigFile(BaseModel): """Contains any configurable values passed via a config file. @@ -47,6 +50,7 @@ class ConfigFile(BaseModel): generate_all_tags: bool = False http_timeout: int = 5 literal_enums: bool = False + alt_json_decoder: Optional[str] = None @staticmethod def load_from_path(path: Path) -> "ConfigFile": @@ -82,6 +86,7 @@ class Config: content_type_overrides: dict[str, str] overwrite: bool output_path: Optional[Path] + alt_json_decoder: Optional[JSONDecoder] @staticmethod def from_sources( @@ -104,6 +109,14 @@ def from_sources( "ruff check --fix .", "ruff format .", ] + + if config_file.alt_json_decoder == "ujson": + json_decoder = JSONDecoder.UJSON + elif config_file.alt_json_decoder == "orjson": + json_decoder = JSONDecoder.ORJSON + else: + json_decoder = None + config = Config( meta_type=meta_type, @@ -119,6 +132,7 @@ def from_sources( generate_all_tags=config_file.generate_all_tags, http_timeout=config_file.http_timeout, literal_enums=config_file.literal_enums, + alt_json_decoder=json_decoder, document_source=document_source, file_encoding=file_encoding, overwrite=overwrite, diff --git a/openapi_python_client/parser/responses.py b/openapi_python_client/parser/responses.py index 704a35f2d..984f1e533 100644 --- a/openapi_python_client/parser/responses.py +++ b/openapi_python_client/parser/responses.py @@ -37,6 +37,7 @@ class _ResponseSource(TypedDict): JSON_SOURCE = _ResponseSource(attribute="response.json()", return_type="Any") +ALT_JSON_SOURCE = _ResponseSource(attribute="loads(response.content)", return_type="Any") BYTES_SOURCE = _ResponseSource(attribute="response.content", return_type="bytes") TEXT_SOURCE = _ResponseSource(attribute="response.text", return_type="str") NONE_SOURCE = _ResponseSource(attribute="None", return_type="None") @@ -135,15 +136,20 @@ def _source_by_content_type(content_type: str, config: Config) -> Optional[_Resp if parsed_content_type.startswith("text/"): return TEXT_SOURCE + + if config.alt_json_decoder: + USED_JSON_SOURCE = ALT_JSON_SOURCE + else: + USED_JSON_SOURCE = JSON_SOURCE known_content_types = { - "application/json": JSON_SOURCE, + "application/json": USED_JSON_SOURCE, "application/octet-stream": BYTES_SOURCE, } source = known_content_types.get(parsed_content_type) if source is None and parsed_content_type.endswith("+json"): # Implements https://www.rfc-editor.org/rfc/rfc6838#section-4.2.8 for the +json suffix - source = JSON_SOURCE + source = USED_JSON_SOURCE return source diff --git a/openapi_python_client/templates/endpoint_module.py.jinja b/openapi_python_client/templates/endpoint_module.py.jinja index a7b82df90..c3f048ce4 100644 --- a/openapi_python_client/templates/endpoint_module.py.jinja +++ b/openapi_python_client/templates/endpoint_module.py.jinja @@ -2,6 +2,9 @@ from http import HTTPStatus from typing import Any, Optional, Union, cast import httpx +{% if endpoint.json_decoder %} +from {{ endpoint.json_decoder }} import loads +{% endif %} from ...client import AuthenticatedClient, Client from ...types import Response, UNSET diff --git a/openapi_python_client/templates/package_macros.py.jinja b/openapi_python_client/templates/package_macros.py.jinja new file mode 100644 index 000000000..d94f635cf --- /dev/null +++ b/openapi_python_client/templates/package_macros.py.jinja @@ -0,0 +1,15 @@ +{% macro json_package(config) %} +{% if config.alt_json_decoder == "ujson" %} +ujson +{% elif config.alt_json_decoder == "orjson" %} +orjson +{% endif %} +{% endmacro %} + +{% macro json_package_ver(config) %} +{% if config.alt_json_decoder == "ujson" %} +>=5.11.0 +{% elif config.alt_json_decoder == "orjson" %} +>=3.11.3 +{% endif %} +{% endmacro %} \ No newline at end of file diff --git a/openapi_python_client/templates/pyproject_pdm.toml.jinja b/openapi_python_client/templates/pyproject_pdm.toml.jinja index 82b50ea52..fdd8b2717 100644 --- a/openapi_python_client/templates/pyproject_pdm.toml.jinja +++ b/openapi_python_client/templates/pyproject_pdm.toml.jinja @@ -9,6 +9,9 @@ dependencies = [ "httpx>=0.23.0,<0.29.0", "attrs>=22.2.0", "python-dateutil>=2.8.0", +{% if config.alt_json_decoder %} + "{{ json_package(config) }}{{ json_package_vers(config) }}", +{% endif %} ] [tool.pdm] diff --git a/openapi_python_client/templates/pyproject_poetry.toml.jinja b/openapi_python_client/templates/pyproject_poetry.toml.jinja index 9897ddadb..1ae70e0dd 100644 --- a/openapi_python_client/templates/pyproject_poetry.toml.jinja +++ b/openapi_python_client/templates/pyproject_poetry.toml.jinja @@ -14,6 +14,9 @@ python = "^3.9" httpx = ">=0.23.0,<0.29.0" attrs = ">=22.2.0" python-dateutil = "^2.8.0" +{% if config.alt_json_decoder %} +{{ json_package(config) }} = "{{ json_package_vers(config) }}" +{% endif %} [build-system] requires = ["poetry-core>=2.0.0,<3.0.0"] diff --git a/openapi_python_client/templates/pyproject_uv.toml.jinja b/openapi_python_client/templates/pyproject_uv.toml.jinja index 83634d3d6..f3084380c 100644 --- a/openapi_python_client/templates/pyproject_uv.toml.jinja +++ b/openapi_python_client/templates/pyproject_uv.toml.jinja @@ -9,6 +9,9 @@ dependencies = [ "httpx>=0.23.0,<0.29.0", "attrs>=22.2.0", "python-dateutil>=2.8.0,<3", +{% if config.alt_json_decoder %} + "{{ json_package(config) }}{{ json_package_vers(config) }}", +{% endif %} ] [tool.uv.build-backend] diff --git a/tests/test_config.py b/tests/test_config.py index be2e8bf59..a8060b73c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -7,7 +7,7 @@ import pytest from ruamel.yaml import YAML as _YAML -from openapi_python_client.config import ConfigFile +from openapi_python_client.config import ConfigFile, JSONDecoder class YAML(_YAML): @@ -49,6 +49,7 @@ def test_load_from_path(tmp_path: Path, filename, dump, relative) -> None: "project_name_override": "project-name", "package_name_override": "package_name", "package_version_override": "package_version", + "alt_json_decoder": "orjson", } yml_file.write_text(dump(data)) @@ -59,3 +60,4 @@ def test_load_from_path(tmp_path: Path, filename, dump, relative) -> None: assert config.project_name_override == "project-name" assert config.package_name_override == "package_name" assert config.package_version_override == "package_version" + assert config.alt_json_decoder == JSONDecoder.ORJSON diff --git a/tests/test_parser/test_responses.py b/tests/test_parser/test_responses.py index 22c3ba613..c5e2a2690 100644 --- a/tests/test_parser/test_responses.py +++ b/tests/test_parser/test_responses.py @@ -3,10 +3,12 @@ import pytest import openapi_python_client.schema as oai +from openapi_python_client.config import JSONDecoder from openapi_python_client.parser import responses from openapi_python_client.parser.errors import ParseError, PropertyError from openapi_python_client.parser.properties import Schemas from openapi_python_client.parser.responses import ( + ALT_JSON_SOURCE, JSON_SOURCE, NONE_SOURCE, HTTPStatusPattern, @@ -128,6 +130,7 @@ def test_response_from_data_property(mocker, any_property_factory): ) config = MagicMock() config.content_type_overrides = {} + config.alt_json_decoder = None status_code = HTTPStatusPattern(pattern="400", code_range=(400, 400)) response, schemas = responses.response_from_data( @@ -164,6 +167,7 @@ def test_response_from_data_reference(mocker, any_property_factory): ) config = MagicMock() config.content_type_overrides = {} + config.alt_json_decoder = None response, schemas = responses.response_from_data( status_code=HTTPStatusPattern(pattern="400", code_range=(400, 400)), @@ -182,6 +186,44 @@ def test_response_from_data_reference(mocker, any_property_factory): ) +def test_response_with_alt_decoder(mocker, any_property_factory): + prop = any_property_factory() + property_from_data = mocker.patch.object(responses, "property_from_data", return_value=(prop, Schemas())) + data = oai.Response.model_construct( + description="", + content={"application/json": oai.MediaType.model_construct(media_type_schema="something")}, + ) + config = MagicMock() + config.content_type_overrides = {} + config.alt_json_decoder = JSONDecoder.ORJSON + status_code = HTTPStatusPattern(pattern="400", code_range=(400, 400)) + + response, schemas = responses.response_from_data( + status_code=status_code, + data=data, + schemas=Schemas(), + responses={}, + parent_name="parent", + config=config, + ) + + assert response == responses.Response( + status_code=status_code, + prop=prop, + source=ALT_JSON_SOURCE, + data=data, + + ) + property_from_data.assert_called_once_with( + name="response_400", + required=True, + data="something", + schemas=Schemas(), + parent_name="parent", + config=config, + ) + + @pytest.mark.parametrize( "ref_string,expected_error_string", [