Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion openapi_python_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import mimetypes
from enum import Enum
from pathlib import Path
from typing import Optional, Union
from typing import Literal, Optional, Union

from attr import define
from pydantic import BaseModel
Expand All @@ -28,6 +28,9 @@ class MetaType(str, Enum):
PDM = "pdm"
UV = "uv"

class JSONDecoder(str, Enum):
UJSON = "ujson"
ORJSON = "orjson"

class ConfigFile(BaseModel):
"""Contains any configurable values passed via a config file.
Expand All @@ -47,6 +50,7 @@ class ConfigFile(BaseModel):
generate_all_tags: bool = False
http_timeout: int = 5
literal_enums: bool = False
alt_json_decoder: Optional[str] = None

@staticmethod
def load_from_path(path: Path) -> "ConfigFile":
Expand Down Expand Up @@ -82,6 +86,7 @@ class Config:
content_type_overrides: dict[str, str]
overwrite: bool
output_path: Optional[Path]
alt_json_decoder: Optional[JSONDecoder]

@staticmethod
def from_sources(
Expand All @@ -104,6 +109,14 @@ def from_sources(
"ruff check --fix .",
"ruff format .",
]

if config_file.alt_json_decoder == "ujson":
json_decoder = JSONDecoder.UJSON
elif config_file.alt_json_decoder == "orjson":
json_decoder = JSONDecoder.ORJSON
else:
json_decoder = None


config = Config(
meta_type=meta_type,
Expand All @@ -119,6 +132,7 @@ def from_sources(
generate_all_tags=config_file.generate_all_tags,
http_timeout=config_file.http_timeout,
literal_enums=config_file.literal_enums,
alt_json_decoder=json_decoder,
document_source=document_source,
file_encoding=file_encoding,
overwrite=overwrite,
Expand Down
10 changes: 8 additions & 2 deletions openapi_python_client/parser/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class _ResponseSource(TypedDict):


JSON_SOURCE = _ResponseSource(attribute="response.json()", return_type="Any")
ALT_JSON_SOURCE = _ResponseSource(attribute="loads(response.content)", return_type="Any")
BYTES_SOURCE = _ResponseSource(attribute="response.content", return_type="bytes")
TEXT_SOURCE = _ResponseSource(attribute="response.text", return_type="str")
NONE_SOURCE = _ResponseSource(attribute="None", return_type="None")
Expand Down Expand Up @@ -135,15 +136,20 @@ def _source_by_content_type(content_type: str, config: Config) -> Optional[_Resp

if parsed_content_type.startswith("text/"):
return TEXT_SOURCE

if config.alt_json_decoder:
USED_JSON_SOURCE = ALT_JSON_SOURCE
else:
USED_JSON_SOURCE = JSON_SOURCE

known_content_types = {
"application/json": JSON_SOURCE,
"application/json": USED_JSON_SOURCE,
"application/octet-stream": BYTES_SOURCE,
}
source = known_content_types.get(parsed_content_type)
if source is None and parsed_content_type.endswith("+json"):
# Implements https://www.rfc-editor.org/rfc/rfc6838#section-4.2.8 for the +json suffix
source = JSON_SOURCE
source = USED_JSON_SOURCE
return source


Expand Down
3 changes: 3 additions & 0 deletions openapi_python_client/templates/endpoint_module.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ from http import HTTPStatus
from typing import Any, Optional, Union, cast

import httpx
{% if endpoint.json_decoder %}
from {{ endpoint.json_decoder }} import loads
{% endif %}

from ...client import AuthenticatedClient, Client
from ...types import Response, UNSET
Expand Down
15 changes: 15 additions & 0 deletions openapi_python_client/templates/package_macros.py.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{% macro json_package(config) %}
{% if config.alt_json_decoder == "ujson" %}
ujson
{% elif config.alt_json_decoder == "orjson" %}
orjson
{% endif %}
{% endmacro %}

{% macro json_package_ver(config) %}
{% if config.alt_json_decoder == "ujson" %}
>=5.11.0
{% elif config.alt_json_decoder == "orjson" %}
>=3.11.3
{% endif %}
{% endmacro %}
3 changes: 3 additions & 0 deletions openapi_python_client/templates/pyproject_pdm.toml.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ dependencies = [
"httpx>=0.23.0,<0.29.0",
"attrs>=22.2.0",
"python-dateutil>=2.8.0",
{% if config.alt_json_decoder %}
"{{ json_package(config) }}{{ json_package_vers(config) }}",
{% endif %}
]

[tool.pdm]
Expand Down
3 changes: 3 additions & 0 deletions openapi_python_client/templates/pyproject_poetry.toml.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ python = "^3.9"
httpx = ">=0.23.0,<0.29.0"
attrs = ">=22.2.0"
python-dateutil = "^2.8.0"
{% if config.alt_json_decoder %}
{{ json_package(config) }} = "{{ json_package_vers(config) }}"
{% endif %}

[build-system]
requires = ["poetry-core>=2.0.0,<3.0.0"]
Expand Down
3 changes: 3 additions & 0 deletions openapi_python_client/templates/pyproject_uv.toml.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ dependencies = [
"httpx>=0.23.0,<0.29.0",
"attrs>=22.2.0",
"python-dateutil>=2.8.0,<3",
{% if config.alt_json_decoder %}
"{{ json_package(config) }}{{ json_package_vers(config) }}",
{% endif %}
]
[tool.uv.build-backend]
Expand Down
4 changes: 3 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
from ruamel.yaml import YAML as _YAML

from openapi_python_client.config import ConfigFile
from openapi_python_client.config import ConfigFile, JSONDecoder


class YAML(_YAML):
Expand Down Expand Up @@ -49,6 +49,7 @@ def test_load_from_path(tmp_path: Path, filename, dump, relative) -> None:
"project_name_override": "project-name",
"package_name_override": "package_name",
"package_version_override": "package_version",
"alt_json_decoder": "orjson",
}
yml_file.write_text(dump(data))

Expand All @@ -59,3 +60,4 @@ def test_load_from_path(tmp_path: Path, filename, dump, relative) -> None:
assert config.project_name_override == "project-name"
assert config.package_name_override == "package_name"
assert config.package_version_override == "package_version"
assert config.alt_json_decoder == JSONDecoder.ORJSON
42 changes: 42 additions & 0 deletions tests/test_parser/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import pytest

import openapi_python_client.schema as oai
from openapi_python_client.config import JSONDecoder
from openapi_python_client.parser import responses
from openapi_python_client.parser.errors import ParseError, PropertyError
from openapi_python_client.parser.properties import Schemas
from openapi_python_client.parser.responses import (
ALT_JSON_SOURCE,
JSON_SOURCE,
NONE_SOURCE,
HTTPStatusPattern,
Expand Down Expand Up @@ -128,6 +130,7 @@ def test_response_from_data_property(mocker, any_property_factory):
)
config = MagicMock()
config.content_type_overrides = {}
config.alt_json_decoder = None
status_code = HTTPStatusPattern(pattern="400", code_range=(400, 400))

response, schemas = responses.response_from_data(
Expand Down Expand Up @@ -164,6 +167,7 @@ def test_response_from_data_reference(mocker, any_property_factory):
)
config = MagicMock()
config.content_type_overrides = {}
config.alt_json_decoder = None

response, schemas = responses.response_from_data(
status_code=HTTPStatusPattern(pattern="400", code_range=(400, 400)),
Expand All @@ -182,6 +186,44 @@ def test_response_from_data_reference(mocker, any_property_factory):
)


def test_response_with_alt_decoder(mocker, any_property_factory):
prop = any_property_factory()
property_from_data = mocker.patch.object(responses, "property_from_data", return_value=(prop, Schemas()))
data = oai.Response.model_construct(
description="",
content={"application/json": oai.MediaType.model_construct(media_type_schema="something")},
)
config = MagicMock()
config.content_type_overrides = {}
config.alt_json_decoder = JSONDecoder.ORJSON
status_code = HTTPStatusPattern(pattern="400", code_range=(400, 400))

response, schemas = responses.response_from_data(
status_code=status_code,
data=data,
schemas=Schemas(),
responses={},
parent_name="parent",
config=config,
)

assert response == responses.Response(
status_code=status_code,
prop=prop,
source=ALT_JSON_SOURCE,
data=data,

)
property_from_data.assert_called_once_with(
name="response_400",
required=True,
data="something",
schemas=Schemas(),
parent_name="parent",
config=config,
)


@pytest.mark.parametrize(
"ref_string,expected_error_string",
[
Expand Down