Skip to content

Feature/support header parameters #148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0


## 0.5.3 - Unrelease
### Additions
- Added support for header parameters (#117)

### Fixes
- JSON bodies will now be assigned correctly in generated clients(#139 & #147). Thanks @pawamoy!

Expand Down
4 changes: 2 additions & 2 deletions end_to_end_tests/fastapi_app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pathlib import Path
from typing import Any, Dict, List, Union

from fastapi import APIRouter, FastAPI, File, Query, UploadFile
from fastapi import APIRouter, FastAPI, File, Header, Query, UploadFile
from pydantic import BaseModel

app = FastAPI(title="My Test API", description="An API for testing openapi-python-client",)
Expand Down Expand Up @@ -55,7 +55,7 @@ def get_list(an_enum_value: List[AnEnum] = Query(...), some_date: Union[date, da


@test_router.post("/upload")
async def upload_file(some_file: UploadFile = File(...)):
async def upload_file(some_file: UploadFile = File(...), keep_alive: bool = Header(None)):
""" Upload a file """
data = await some_file.read()
return (some_file.filename, some_file.content_type, data)
Expand Down
11 changes: 11 additions & 0 deletions end_to_end_tests/fastapi_app/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,17 @@
"summary": "Upload File",
"description": "Upload a file ",
"operationId": "upload_file_tests_upload_post",
"parameters": [
{
"required": false,
"schema": {
"title": "Keep-Alive",
"type": "boolean"
},
"name": "keep-alive",
"in": "header"
}
],
"requestBody": {
"content": {
"multipart/form-data": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ def ping_ping_get(*, client: Client,) -> bool:
""" A quick check to see if the system is running """
url = "{}/ping".format(client.base_url)

response = httpx.get(url=url, headers=client.get_headers(),)
headers: Dict[str, Any] = client.get_headers()

response = httpx.get(url=url, headers=headers,)

if response.status_code == 200:
return bool(response.text)
Expand Down
16 changes: 12 additions & 4 deletions end_to_end_tests/golden-master/my_test_api_client/api/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def get_user_list(
""" Get a list of things """
url = "{}/tests/".format(client.base_url)

headers: Dict[str, Any] = client.get_headers()

json_an_enum_value = []
for an_enum_value_item_data in an_enum_value:
an_enum_value_item = an_enum_value_item_data.value
Expand All @@ -38,7 +40,7 @@ def get_user_list(
"some_date": json_some_date,
}

response = httpx.get(url=url, headers=client.get_headers(), params=params,)
response = httpx.get(url=url, headers=headers, params=params,)

if response.status_code == 200:
return [AModel.from_dict(item) for item in cast(List[Dict[str, Any]], response.json())]
Expand All @@ -49,15 +51,19 @@ def get_user_list(


def upload_file_tests_upload_post(
*, client: Client, multipart_data: BodyUploadFileTestsUploadPost,
*, client: Client, multipart_data: BodyUploadFileTestsUploadPost, keep_alive: Optional[bool] = None,
) -> Union[
None, HTTPValidationError,
]:

""" Upload a file """
url = "{}/tests/upload".format(client.base_url)

response = httpx.post(url=url, headers=client.get_headers(), files=multipart_data.to_dict(),)
headers: Dict[str, Any] = client.get_headers()
if keep_alive is not None:
headers["keep-alive"] = keep_alive

response = httpx.post(url=url, headers=headers, files=multipart_data.to_dict(),)

if response.status_code == 200:
return None
Expand All @@ -76,9 +82,11 @@ def json_body_tests_json_body_post(
""" Try sending a JSON body """
url = "{}/tests/json_body".format(client.base_url)

headers: Dict[str, Any] = client.get_headers()

json_json_body = json_body.to_dict()

response = httpx.post(url=url, headers=client.get_headers(), json=json_json_body,)
response = httpx.post(url=url, headers=headers, json=json_json_body,)

if response.status_code == 200:
return None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ async def ping_ping_get(*, client: Client,) -> bool:
""" A quick check to see if the system is running """
url = "{}/ping".format(client.base_url,)

headers: Dict[str, Any] = client.get_headers()

async with httpx.AsyncClient() as _client:
response = await _client.get(url=url, headers=client.get_headers(),)
response = await _client.get(url=url, headers=headers,)

if response.status_code == 200:
return bool(response.text)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ async def get_user_list(
""" Get a list of things """
url = "{}/tests/".format(client.base_url,)

headers: Dict[str, Any] = client.get_headers()

json_an_enum_value = []
for an_enum_value_item_data in an_enum_value:
an_enum_value_item = an_enum_value_item_data.value
Expand All @@ -39,7 +41,7 @@ async def get_user_list(
}

async with httpx.AsyncClient() as _client:
response = await _client.get(url=url, headers=client.get_headers(), params=params,)
response = await _client.get(url=url, headers=headers, params=params,)

if response.status_code == 200:
return [AModel.from_dict(item) for item in cast(List[Dict[str, Any]], response.json())]
Expand All @@ -50,16 +52,20 @@ async def get_user_list(


async def upload_file_tests_upload_post(
*, client: Client, multipart_data: BodyUploadFileTestsUploadPost,
*, client: Client, multipart_data: BodyUploadFileTestsUploadPost, keep_alive: Optional[bool] = None,
) -> Union[
None, HTTPValidationError,
]:

""" Upload a file """
url = "{}/tests/upload".format(client.base_url,)

headers: Dict[str, Any] = client.get_headers()
if keep_alive is not None:
headers["keep-alive"] = keep_alive

async with httpx.AsyncClient() as _client:
response = await _client.post(url=url, headers=client.get_headers(), files=multipart_data.to_dict(),)
response = await _client.post(url=url, headers=headers, files=multipart_data.to_dict(),)

if response.status_code == 200:
return None
Expand All @@ -78,10 +84,12 @@ async def json_body_tests_json_body_post(
""" Try sending a JSON body """
url = "{}/tests/json_body".format(client.base_url,)

headers: Dict[str, Any] = client.get_headers()

json_json_body = json_body.to_dict()

async with httpx.AsyncClient() as _client:
response = await _client.post(url=url, headers=client.get_headers(), json=json_json_body,)
response = await _client.post(url=url, headers=headers, json=json_json_body,)

if response.status_code == 200:
return None
Expand Down
2 changes: 1 addition & 1 deletion openapi_python_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _get_document(*, url: Optional[str], path: Optional[Path]) -> Union[Dict[str


class Project:
TEMPLATE_FILTERS = {"snakecase": utils.snake_case}
TEMPLATE_FILTERS = {"snakecase": utils.snake_case, "spinalcase": utils.spinal_case}
project_name_override: Optional[str] = None
package_name_override: Optional[str] = None

Expand Down
4 changes: 4 additions & 0 deletions openapi_python_client/parser/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class ParameterLocation(str, Enum):

QUERY = "query"
PATH = "path"
HEADER = "header"


def import_string_from_reference(reference: Reference, prefix: str = "") -> str:
Expand Down Expand Up @@ -78,6 +79,7 @@ class Endpoint:
relative_imports: Set[str] = field(default_factory=set)
query_parameters: List[Property] = field(default_factory=list)
path_parameters: List[Property] = field(default_factory=list)
header_parameters: List[Property] = field(default_factory=list)
responses: List[Response] = field(default_factory=list)
form_body_reference: Optional[Reference] = None
json_body: Optional[Property] = None
Expand Down Expand Up @@ -164,6 +166,8 @@ def _add_parameters(endpoint: Endpoint, data: oai.Operation) -> Union[Endpoint,
endpoint.query_parameters.append(prop)
elif param.param_in == ParameterLocation.PATH:
endpoint.path_parameters.append(prop)
elif param.param_in == ParameterLocation.HEADER:
endpoint.header_parameters.append(prop)
else:
return ParseError(data=param, detail="Parameter must be declared in path or query")
return endpoint
Expand Down
10 changes: 8 additions & 2 deletions openapi_python_client/templates/async_endpoint_module.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ from ..errors import ApiResponseError
{% endfor %}
{% for endpoint in collection.endpoints %}

{% from "endpoint_macros.pyi" import query_params, json_body, return_type %}
{% from "endpoint_macros.pyi" import header_params, query_params, json_body, return_type %}

async def {{ endpoint.name | snakecase }}(
*,
Expand Down Expand Up @@ -41,6 +41,9 @@ async def {{ endpoint.name | snakecase }}(
{% for parameter in endpoint.query_parameters %}
{{ parameter.to_string() }},
{% endfor %}
{% for parameter in endpoint.header_parameters %}
{{ parameter.to_string() }},
{% endfor %}
{{ return_type(endpoint) }}
""" {{ endpoint.description }} """
url = "{}{{ endpoint.path }}".format(
Expand All @@ -50,13 +53,16 @@ async def {{ endpoint.name | snakecase }}(
{% endfor %}
)

headers: Dict[str, Any] = client.get_headers()
{{ header_params(endpoint) | indent(4) }}

{{ query_params(endpoint) | indent(4) }}
{{ json_body(endpoint) | indent(4) }}

async with httpx.AsyncClient() as _client:
response = await _client.{{ endpoint.method }}(
url=url,
headers=client.get_headers(),
headers=headers,
{% if endpoint.form_body_reference %}
data=asdict(form_data),
{% endif %}
Expand Down
13 changes: 13 additions & 0 deletions openapi_python_client/templates/endpoint_macros.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
{% macro header_params(endpoint) %}
{% if endpoint.header_parameters %}
{% for parameter in endpoint.header_parameters %}
{% if parameter.required %}
headers["{{ parameter.python_name | spinalcase}}"] = {{ parameter.python_name }}
{% else %}
if {{ parameter.python_name }} is not None:
headers["{{ parameter.python_name | spinalcase}}"] = {{ parameter.python_name }}
{% endif %}
{% endfor %}
{% endif %}
{% endmacro %}

{% macro query_params(endpoint) %}
{% if endpoint.query_parameters %}
{% for property in endpoint.query_parameters %}
Expand Down
10 changes: 8 additions & 2 deletions openapi_python_client/templates/endpoint_module.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ from ..errors import ApiResponseError
{% endfor %}
{% for endpoint in collection.endpoints %}

{% from "endpoint_macros.pyi" import query_params, json_body, return_type %}
{% from "endpoint_macros.pyi" import header_params, query_params, json_body, return_type %}

def {{ endpoint.name | snakecase }}(
*,
Expand Down Expand Up @@ -41,6 +41,9 @@ def {{ endpoint.name | snakecase }}(
{% for parameter in endpoint.query_parameters %}
{{ parameter.to_string() }},
{% endfor %}
{% for parameter in endpoint.header_parameters %}
{{ parameter.to_string() }},
{% endfor %}
{{ return_type(endpoint) }}
""" {{ endpoint.description }} """
url = "{}{{ endpoint.path }}".format(
Expand All @@ -50,14 +53,17 @@ def {{ endpoint.name | snakecase }}(
{%- endfor -%}
)

headers: Dict[str, Any] = client.get_headers()
{{ header_params(endpoint) | indent(4) }}

{{ query_params(endpoint) | indent(4) }}

{{ json_body(endpoint) | indent(4) }}


response = httpx.{{ endpoint.method }}(
url=url,
headers=client.get_headers(),
headers=headers,
{% if endpoint.form_body_reference %}
data=asdict(form_data),
{% endif %}
Expand Down
4 changes: 4 additions & 0 deletions openapi_python_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ def snake_case(value: str) -> str:

def pascal_case(value: str) -> str:
return stringcase.pascalcase(value)


def spinal_case(value: str) -> str:
return stringcase.spinalcase(value)
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ isort .\
openapi = "python -m end_to_end_tests.fastapi_app"
gm = "python -m end_to_end_tests.regen_golden_master"
e2e = "pytest openapi_python_client end_to_end_tests"
oge = """
task openapi\
&& task gm\
&& task e2e\
"""

[tool.black]
line-length = 120
Expand Down
20 changes: 14 additions & 6 deletions tests/test_openapi_parser/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,9 +452,15 @@ def test__add_parameters_happy(self, mocker):
query_prop = mocker.MagicMock(autospec=Property)
query_prop_import = mocker.MagicMock()
query_prop.get_imports = mocker.MagicMock(return_value={query_prop_import})
property_from_data = mocker.patch(f"{MODULE_NAME}.property_from_data", side_effect=[path_prop, query_prop])
header_prop = mocker.MagicMock(autospec=Property)
header_prop_import = mocker.MagicMock()
header_prop.get_imports = mocker.MagicMock(return_value={header_prop_import})
property_from_data = mocker.patch(
f"{MODULE_NAME}.property_from_data", side_effect=[path_prop, query_prop, header_prop]
)
path_schema = mocker.MagicMock()
query_schema = mocker.MagicMock()
header_schema = mocker.MagicMock()
data = oai.Operation.construct(
parameters=[
oai.Parameter.construct(
Expand All @@ -463,6 +469,9 @@ def test__add_parameters_happy(self, mocker):
oai.Parameter.construct(
name="query_prop_name", required=False, param_schema=query_schema, param_in="query"
),
oai.Parameter.construct(
name="header_prop_name", required=False, param_schema=header_schema, param_in="header"
),
oai.Reference.construct(), # Should be ignored
oai.Parameter.construct(), # Should be ignored
]
Expand All @@ -474,17 +483,16 @@ def test__add_parameters_happy(self, mocker):
[
mocker.call(name="path_prop_name", required=True, data=path_schema),
mocker.call(name="query_prop_name", required=False, data=query_schema),
mocker.call(name="header_prop_name", required=False, data=header_schema),
]
)
path_prop.get_imports.assert_called_once_with(prefix="..models")
query_prop.get_imports.assert_called_once_with(prefix="..models")
assert endpoint.relative_imports == {
"import_3",
path_prop_import,
query_prop_import,
}
header_prop.get_imports.assert_called_once_with(prefix="..models")
assert endpoint.relative_imports == {"import_3", path_prop_import, query_prop_import, header_prop_import}
assert endpoint.path_parameters == [path_prop]
assert endpoint.query_parameters == [query_prop]
assert endpoint.header_parameters == [header_prop]

def test_from_data_bad_params(self, mocker):
from openapi_python_client.parser.openapi import Endpoint
Expand Down
Loading