Skip to content

Convert property and endpoint names to snake_case #36

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 25, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixes
- Fixed some typing issues in generated clients and incorporate mypy into end to end tests (#32). Thanks @acgray!
- Properly handle camelCase endpoint names and properties (#29, #36). Thanks @acgray!

## 0.2.1 - 2020-03-22
### Fixes
Expand Down
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,17 @@ You can pass a YAML (or JSON) file to openapi-python-client in order to change s
are supported:

### class_overrides
Used to change the name of generated model classes, especially useful if you have a name like ABCModel which, when
converted to snake case for module naming will be a_b_c_model. This param should be a mapping of existing class name
(usually a key in the "schemas" section of your OpenAPI document) to class_name and module_name.
Used to change the name of generated model classes. This param should be a mapping of existing class name
(usually a key in the "schemas" section of your OpenAPI document) to class_name and module_name. As an example, if the
name of the a model in OpenAPI (and therefore the generated class name) was something like "_PrivateInternalLongName"
and you want the generated client's model to be called "ShortName" in a module called "short_name" you could do this:

Example:
```yaml
class_overrides:
ABCModel:
class_name: ABCModel
module_name: abc_model
_PrivateInternalLongName:
class_name: ShortName
module_name: short_name
```

The easiest way to find what needs to be overridden is probably to generate your client and go look at everything in the
Expand Down
6 changes: 6 additions & 0 deletions openapi_python_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import yaml
from jinja2 import Environment, PackageLoader

from openapi_python_client import utils

from .openapi_parser import OpenAPI, import_string_from_reference

__version__ = version(__package__)
Expand Down Expand Up @@ -61,6 +63,8 @@ def _get_json(*, url: Optional[str], path: Optional[Path]) -> Dict[str, Any]:


class _Project:
TEMPLATE_FILTERS = {"snakecase": utils.snake_case}

def __init__(self, *, openapi: OpenAPI) -> None:
self.openapi: OpenAPI = openapi
self.env: Environment = Environment(loader=PackageLoader(__package__), trim_blocks=True, lstrip_blocks=True)
Expand All @@ -72,6 +76,8 @@ def __init__(self, *, openapi: OpenAPI) -> None:
self.package_dir: Path = self.project_dir / self.package_name
self.package_description = f"A client library for accessing {self.openapi.title}"

self.env.filters.update(self.TEMPLATE_FILTERS)

def build(self) -> None:
""" Create the project from templates """

Expand Down
20 changes: 15 additions & 5 deletions openapi_python_client/openapi_parser/properties.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dataclasses import dataclass, field
from typing import Any, ClassVar, Dict, List, Optional

from openapi_python_client import utils

from .reference import Reference


Expand All @@ -15,6 +17,11 @@ class Property:
constructor_template: ClassVar[Optional[str]] = None
_type_string: ClassVar[str]

python_name: str = field(init=False)

def __post_init__(self) -> None:
self.python_name = utils.snake_case(self.name)

def get_type_string(self) -> str:
""" Get a string representation of type that should be used when declaring this property """
if self.required:
Expand All @@ -31,13 +38,13 @@ def to_string(self) -> str:
default = None

if default is not None:
return f"{self.name}: {self.get_type_string()} = {self.default}"
return f"{self.python_name}: {self.get_type_string()} = {self.default}"
else:
return f"{self.name}: {self.get_type_string()}"
return f"{self.python_name}: {self.get_type_string()}"

def transform(self) -> str:
""" What it takes to turn this object into a native python type """
return self.name
return self.python_name

def constructor_from_dict(self, dict_name: str) -> str:
""" How to load this property from a dict (used in generated model from_dict function """
Expand All @@ -57,6 +64,7 @@ class StringProperty(Property):
_type_string: ClassVar[str] = "str"

def __post_init__(self) -> None:
super().__post_init__()
if self.default is not None:
self.default = f'"{self.default}"'

Expand Down Expand Up @@ -132,6 +140,7 @@ class EnumListProperty(Property):
constructor_template: ClassVar[str] = "enum_list_property.pyi"

def __post_init__(self) -> None:
super().__post_init__()
self.reference = Reference.from_ref(self.name)

def get_type_string(self) -> str:
Expand All @@ -149,6 +158,7 @@ class EnumProperty(Property):
reference: Reference = field(init=False)

def __post_init__(self) -> None:
super().__post_init__()
self.reference = Reference.from_ref(self.name)
inverse_values = {v: k for k, v in self.values.items()}
if self.default is not None:
Expand All @@ -163,7 +173,7 @@ def get_type_string(self) -> str:

def transform(self) -> str:
""" Output to the template, convert this Enum into a JSONable value """
return f"{self.name}.value"
return f"{self.python_name}.value"

def constructor_from_dict(self, dict_name: str) -> str:
""" How to load this property from a dict (used in generated model from_dict function """
Expand Down Expand Up @@ -204,7 +214,7 @@ def get_type_string(self) -> str:

def transform(self) -> str:
""" Convert this into a JSONable value """
return f"{self.name}.to_dict()"
return f"{self.python_name}.to_dict()"


@dataclass
Expand Down
6 changes: 3 additions & 3 deletions openapi_python_client/openapi_parser/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
from typing import Dict

import stringcase
from .. import utils

class_overrides: Dict[str, Reference] = {}

Expand All @@ -21,9 +21,9 @@ class Reference:
def from_ref(ref: str) -> Reference:
""" Get a Reference from the openapi #/schemas/blahblah string """
ref_value = ref.split("/")[-1]
class_name = stringcase.pascalcase(ref_value)
class_name = utils.pascal_case(ref_value)

if class_name in class_overrides:
return class_overrides[class_name]

return Reference(class_name=class_name, module_name=stringcase.snakecase(ref_value),)
return Reference(class_name=class_name, module_name=utils.snake_case(ref_value),)
13 changes: 9 additions & 4 deletions openapi_python_client/templates/async_endpoint_module.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ from ..errors import ApiResponseError
{% for endpoint in collection.endpoints %}


async def {{ endpoint.name }}(
async def {{ endpoint.name | snakecase }}(
*,
{# Proper client based on whether or not the endpoint requires authentication #}
{% if endpoint.requires_security %}
Expand Down Expand Up @@ -42,7 +42,12 @@ async def {{ endpoint.name }}(
{% endfor %}
]:
""" {{ endpoint.description }} """
url = f"{client.base_url}{{ endpoint.path }}"
url = "{}{{ endpoint.path }}".format(
client.base_url
{%- for parameter in endpoint.path_parameters -%}
,{{parameter.name}}={{parameter.python_name}}
{%- endfor -%}
)

{% if endpoint.query_parameters %}
params = {
Expand All @@ -54,8 +59,8 @@ async def {{ endpoint.name }}(
}
{% for parameter in endpoint.query_parameters %}
{% if not parameter.required %}
if {{ parameter.name }} is not None:
params["{{ parameter.name }}"] = {{ parameter.transform() }}
if {{ parameter.python_name }} is not None:
params["{{ parameter.name }}"] = str({{ parameter.transform() }})
{% endif %}
{% endfor %}
{% endif %}
Expand Down
8 changes: 4 additions & 4 deletions openapi_python_client/templates/datetime_property.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{% if property.required %}
{{ property.name }} = datetime.fromisoformat(d["{{ property.name }}"])
{{ property.python_name }} = datetime.fromisoformat(d["{{ property.name }}"])
{% else %}
{{ property.name }} = None
if ({{ property.name }}_string := d.get("{{ property.name }}")) is not None:
{{ property.name }} = datetime.fromisoformat(cast(str, {{ property.name }}_string))
{{ property.python_name }} = None
if ({{ property.python_name }}_string := d.get("{{ property.name }}")) is not None:
{{ property.python_name }} = datetime.fromisoformat(cast(str, {{ property.python_name }}_string))
{% endif %}
13 changes: 9 additions & 4 deletions openapi_python_client/templates/endpoint_module.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ from ..errors import ApiResponseError
{% for endpoint in collection.endpoints %}


def {{ endpoint.name }}(
def {{ endpoint.name | snakecase }}(
*,
{# Proper client based on whether or not the endpoint requires authentication #}
{% if endpoint.requires_security %}
Expand Down Expand Up @@ -42,7 +42,12 @@ def {{ endpoint.name }}(
{% endfor %}
]:
""" {{ endpoint.description }} """
url = f"{client.base_url}{{ endpoint.path }}"
url = "{}{{ endpoint.path }}".format(
client.base_url
{%- for parameter in endpoint.path_parameters -%}
,{{parameter.name}}={{parameter.python_name}}
{%- endfor -%}
)

{% if endpoint.query_parameters %}
params = {
Expand All @@ -54,8 +59,8 @@ def {{ endpoint.name }}(
}
{% for parameter in endpoint.query_parameters %}
{% if not parameter.required %}
if {{ parameter.name }} is not None:
params["{{ parameter.name }}"] = {{ parameter.transform() }}
if {{ parameter.python_name }} is not None:
params["{{ parameter.name }}"] = str({{ parameter.transform() }})
{% endif %}
{% endfor %}
{% endif %}
Expand Down
6 changes: 3 additions & 3 deletions openapi_python_client/templates/enum_list_property.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{{ property.name }} = []
for {{ property.name }}_item in d.get("{{ property.name }}", []):
{{ property.name }}.append({{ property.reference.class_name }}({{ property.name }}_item))
{{ property.python_name }} = []
for {{ property.python_name }}_item in d.get("{{ property.name }}", []):
{{ property.python_name }}.append({{ property.reference.class_name }}({{ property.python_name }}_item))
8 changes: 4 additions & 4 deletions openapi_python_client/templates/model.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class {{ schema.reference.class_name }}:
"{{ property.name }}": self.{{ property.transform() }},
{% endfor %}
{% for property in schema.optional_properties %}
"{{ property.name }}": self.{{ property.transform() }} if self.{{ property.name }} is not None else None,
{% endfor %}
"{{ property.name }}": self.{{ property.transform() }} if self.{{ property.python_name }} is not None else None,
{% endfor %}
}

@staticmethod
Expand All @@ -33,12 +33,12 @@ class {{ schema.reference.class_name }}:
{% if property.constructor_template %}
{% include property.constructor_template %}
{% else %}
{{ property.name }} = {{ property.constructor_from_dict("d") }}
{{ property.python_name }} = {{ property.constructor_from_dict("d") }}
{% endif %}

{% endfor %}
return {{ schema.reference.class_name }}(
{% for property in schema.required_properties + schema.optional_properties %}
{{ property.name }}={{ property.name }},
{{ property.python_name }}={{ property.python_name }},
{% endfor %}
)
8 changes: 4 additions & 4 deletions openapi_python_client/templates/ref_property.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{% if property.required %}
{{ property.name }} = {{ property.reference.class_name }}.from_dict(d["{{ property.name }}"])
{{ property.python_name }} = {{ property.reference.class_name }}.from_dict(d["{{ property.name }}"])
{% else %}
{{ property.name }} = None
if ({{ property.name }}_data := d.get("{{ property.name }}")) is not None:
{{ property.name }} = {{ property.reference.class_name }}.from_dict(cast(Dict, {{ property.name }}_data))
{{ property.python_name }} = None
if ({{ property.python_name }}_data := d.get("{{ property.name }}")) is not None:
{{ property.python_name }} = {{ property.reference.class_name }}.from_dict(cast(Dict[str, Any], {{ property.python_name }}_data))
{% endif %}
6 changes: 3 additions & 3 deletions openapi_python_client/templates/reference_list_property.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{{ property.name }} = []
for {{ property.name }}_item in d.get("{{ property.name }}", []):
{{ property.name }}.append({{ property.reference.class_name }}.from_dict({{ property.name }}_item))
{{ property.python_name }} = []
for {{ property.python_name }}_item in d.get("{{ property.python_name }}", []):
{{ property.python_name }}.append({{ property.reference.class_name }}.from_dict({{ property.python_name }}_item))
13 changes: 13 additions & 0 deletions openapi_python_client/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import re

import stringcase


def snake_case(value: str) -> str:
value = re.sub(r"([A-Z]{2,})([A-Z][a-z]|[ -_]|$)", lambda m: m.group(1).title() + m.group(2), value.strip())
value = re.sub(r"(^|[ _-])([A-Z])", lambda m: m.group(1) + m.group(2).lower(), value)
return stringcase.snakecase(value)


def pascal_case(value: str) -> str:
return stringcase.pascalcase(value)
6 changes: 4 additions & 2 deletions tests/test_end_to_end/fastapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
""" A FastAPI app used to create an OpenAPI document for end-to-end testing """
import json
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import List
Expand Down Expand Up @@ -43,9 +44,10 @@ class AModel(BaseModel):
a_list_of_enums: List[AnEnum]
a_list_of_strings: List[str]
a_list_of_objects: List[OtherModel]
aCamelDateTime: datetime


@test_router.get("/", response_model=List[AModel])
@test_router.get("/", response_model=List[AModel], operation_id="getUserList")
def get_list(statuses: List[AnEnum] = Query(...),):
""" Get users, filtered by statuses """
return
Expand All @@ -55,4 +57,4 @@ def get_list(statuses: List[AnEnum] = Query(...),):

if __name__ == "__main__":
path = Path(__file__).parent / "openapi.json"
path.write_text(json.dumps(app.openapi()))
path.write_text(json.dumps(app.openapi(), indent=4))
Loading