diff --git a/.changeset/allow_case_sensitive_enum_values.md b/.changeset/allow_case_sensitive_enum_values.md new file mode 100644 index 000000000..514f042f4 --- /dev/null +++ b/.changeset/allow_case_sensitive_enum_values.md @@ -0,0 +1,13 @@ +--- +default: major +--- + +# Allow case sensitive enum values + +#725 by @expobrain + +Added setting `case_sensitive_enums` to allow case sensitive enum values in the generated code. + +This solve the issue in #587 . + +Also, to avoid collisions, changes the prefix for values not starting with alphanumeric characters from `VALUE_` to `LITERAL_`. diff --git a/openapi_python_client/config.py b/openapi_python_client/config.py index 740e06309..7318ef04e 100644 --- a/openapi_python_client/config.py +++ b/openapi_python_client/config.py @@ -43,6 +43,7 @@ class ConfigFile(BaseModel): post_hooks: Optional[List[str]] = None field_prefix: str = "field_" http_timeout: int = 5 + case_sensitive_enums: bool = False @staticmethod def load_from_path(path: Path) -> "ConfigFile": @@ -75,6 +76,7 @@ class Config: content_type_overrides: Dict[str, str] overwrite: bool output_path: Optional[Path] + case_sensitive_enums: bool @staticmethod def from_sources( @@ -113,5 +115,6 @@ def from_sources( file_encoding=file_encoding, overwrite=overwrite, output_path=output_path, + case_sensitive_enums=config_file.case_sensitive_enums, ) return config diff --git a/openapi_python_client/parser/properties/enum_property.py b/openapi_python_client/parser/properties/enum_property.py index 0f0db0d61..e9cb72c6e 100644 --- a/openapi_python_client/parser/properties/enum_property.py +++ b/openapi_python_client/parser/properties/enum_property.py @@ -2,7 +2,7 @@ __all__ = ["EnumProperty"] -from typing import Any, ClassVar, List, Union, cast +from typing import Any, ClassVar, List, Sequence, Union, cast from attr import evolve from attrs import define @@ -121,7 +121,7 @@ def build( # noqa: PLR0911 if parent_name: class_name = f"{utils.pascal_case(parent_name)}{utils.pascal_case(class_name)}" class_info = Class.from_string(string=class_name, config=config) - values = EnumProperty.values_from_list(value_list) + values = EnumProperty.values_from_list(value_list, case_sensitive_enums=config.case_sensitive_enums) if class_info.name in schemas.classes_by_name: existing = schemas.classes_by_name[class_info.name] @@ -183,24 +183,30 @@ def get_imports(self, *, prefix: str) -> set[str]: return imports @staticmethod - def values_from_list(values: list[str] | list[int]) -> dict[str, ValueType]: + def values_from_list( + values: Sequence[str] | Sequence[int], case_sensitive_enums: bool = False + ) -> dict[str, ValueType]: """Convert a list of values into dict of {name: value}, where value can sometimes be None""" output: dict[str, ValueType] = {} - for i, value in enumerate(values): - value = cast(Union[str, int], value) + for value in values: if isinstance(value, int): if value < 0: output[f"VALUE_NEGATIVE_{-value}"] = value else: output[f"VALUE_{value}"] = value continue - if value and value[0].isalpha(): - key = value.upper() + + if case_sensitive_enums: + sanitized_key = utils.case_sensitive_snake_case(value) else: - key = f"VALUE_{i}" - if key in output: - raise ValueError(f"Duplicate key {key} in Enum") - sanitized_key = utils.snake_case(key).upper() + sanitized_key = utils.snake_case(value.lower()).upper() + if not value or not value[0].isalpha(): + sanitized_key = f"LITERAL_{sanitized_key}" + + if sanitized_key in output: + raise ValueError(f"Duplicate key {sanitized_key} in Enum") + output[sanitized_key] = utils.remove_string_escapes(value) + return output diff --git a/openapi_python_client/utils.py b/openapi_python_client/utils.py index 22a7bcfa8..682203f87 100644 --- a/openapi_python_client/utils.py +++ b/openapi_python_client/utils.py @@ -77,10 +77,19 @@ def fix_reserved_words(value: str) -> str: return value +def case_sensitive_snake_case(value: str) -> str: + """Converts to snake_case, but preserves capitalization of acronyms""" + words = split_words(sanitize(value)) + value = "_".join(words) + + return value + + def snake_case(value: str) -> str: """Converts to snake_case""" - words = split_words(sanitize(value)) - return "_".join(words).lower() + value = case_sensitive_snake_case(value).lower() + + return value def pascal_case(value: str) -> str: diff --git a/tests/test_config.py b/tests/test_config.py index 2e39bbf4e..2776b1163 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -39,6 +39,7 @@ def test_load_from_path(tmp_path: Path, filename, dump, relative): "project_name_override": "project-name", "package_name_override": "package_name", "package_version_override": "package_version", + "case_sensitive_enums": True, } yml_file.write_text(dump(data)) @@ -49,3 +50,4 @@ def test_load_from_path(tmp_path: Path, filename, dump, relative): assert config.project_name_override == "project-name" assert config.package_name_override == "package_name" assert config.package_version_override == "package_version" + assert config.case_sensitive_enums is True diff --git a/tests/test_parser/test_properties/test_init.py b/tests/test_parser/test_properties/test_init.py index a30059a93..54f7feac0 100644 --- a/tests/test_parser/test_properties/test_init.py +++ b/tests/test_parser/test_properties/test_init.py @@ -355,13 +355,32 @@ def test_values_from_list(self): assert result == { "ABC": "abc", - "VALUE_1": "123", + "LITERAL_123": "123", "A23": "a23", - "VALUE_3": "1bc", + "LITERAL_1BC": "1bc", "VALUE_4": 4, "VALUE_NEGATIVE_3": -3, "A_THING_WITH_SPACES": "a Thing WIth spaces", - "VALUE_7": "", + "LITERAL_": "", + } + + def test_values_from_list_with_case_sesitive(self): + from openapi_python_client.parser.properties import EnumProperty + + data = ["abc", "Abc", "123", "a23", "1bc", 4, -3, "a Thing WIth spaces", ""] + + result = EnumProperty.values_from_list(data, case_sensitive_enums=True) + + assert result == { + "abc": "abc", + "Abc": "Abc", + "LITERAL_123": "123", + "a23": "a23", + "LITERAL_1bc": "1bc", + "VALUE_4": 4, + "VALUE_NEGATIVE_3": -3, + "a_Thing_W_Ith_spaces": "a Thing WIth spaces", + "LITERAL_": "", } def test_values_from_list_duplicate(self):