diff --git a/openapi_core/unmarshalling/schemas/factories.py b/openapi_core/unmarshalling/schemas/factories.py index 89db3047..66184cba 100644 --- a/openapi_core/unmarshalling/schemas/factories.py +++ b/openapi_core/unmarshalling/schemas/factories.py @@ -1,6 +1,7 @@ import warnings from typing import Any from typing import Dict +from typing import Iterable from typing import Optional from typing import Type from typing import Union @@ -30,6 +31,9 @@ from openapi_core.unmarshalling.schemas.unmarshallers import ( IntegerUnmarshaller, ) +from openapi_core.unmarshalling.schemas.unmarshallers import ( + MultiTypeUnmarshaller, +) from openapi_core.unmarshalling.schemas.unmarshallers import NullUnmarshaller from openapi_core.unmarshalling.schemas.unmarshallers import NumberUnmarshaller from openapi_core.unmarshalling.schemas.unmarshallers import ObjectUnmarshaller @@ -89,6 +93,12 @@ def create( formatter = self.custom_formatters.get(schema_format) schema_type = type_override or schema.getkey("type", "any") + if isinstance(schema_type, Iterable) and not isinstance( + schema_type, str + ): + return MultiTypeUnmarshaller( + schema, validator, formatter, self, context=self.context + ) if schema_type in self.COMPLEX_UNMARSHALLERS: complex_klass = self.COMPLEX_UNMARSHALLERS[schema_type] return complex_klass( diff --git a/openapi_core/unmarshalling/schemas/unmarshallers.py b/openapi_core/unmarshalling/schemas/unmarshallers.py index 3a738440..c2704a5c 100644 --- a/openapi_core/unmarshalling/schemas/unmarshallers.py +++ b/openapi_core/unmarshalling/schemas/unmarshallers.py @@ -312,6 +312,27 @@ def _unmarshal_object(self, value: Any) -> Any: return properties +class MultiTypeUnmarshaller(ComplexUnmarshaller): + @property + def types_unmarshallers(self) -> List["BaseSchemaUnmarshaller"]: + types = self.schema.getkey("type", ["any"]) + unmarshaller = partial(self.unmarshallers_factory.create, self.schema) + return list(map(unmarshaller, types)) + + def unmarshal(self, value: Any) -> Any: + for unmarshaller in self.types_unmarshallers: + # validate with validator of formatter (usualy type validator) + try: + unmarshaller._formatter_validate(value) + except ValidateError: + continue + else: + return unmarshaller(value) + + log.warning("failed to unmarshal multi type") + return value + + class AnyUnmarshaller(ComplexUnmarshaller): SCHEMA_TYPES_ORDER = [ diff --git a/tests/unit/unmarshalling/test_unmarshal.py b/tests/unit/unmarshalling/test_unmarshal.py index 224b00a7..3ce50db4 100644 --- a/tests/unit/unmarshalling/test_unmarshal.py +++ b/tests/unit/unmarshalling/test_unmarshal.py @@ -835,6 +835,34 @@ def test_additional_properties_list(self, unmarshaller_factory): "user_ids": [1, 2, 3, 4], } + @pytest.mark.xfail(message="None and NOTSET should be distinguished") + def test_null_not_supported(self, unmarshaller_factory): + schema = {"type": "null"} + spec = Spec.from_dict(schema) + + with pytest.raises(InvalidSchemaValue): + unmarshaller_factory(spec)(None) + + @pytest.mark.parametrize( + "types,value", + [ + (["string", "null"], "string"), + (["number", "null"], 2), + (["number", "null"], 3.14), + (["boolean", "null"], True), + (["array", "null"], [1, 2]), + (["object", "null"], {}), + ], + ) + def test_nultiple_types_not_supported( + self, unmarshaller_factory, types, value + ): + schema = {"type": types} + spec = Spec.from_dict(schema) + + with pytest.raises(TypeError): + unmarshaller_factory(spec)(value) + class TestOAS31SchemaUnmarshallerCall: @pytest.fixture @@ -856,3 +884,40 @@ def test_null_invalid(self, unmarshaller_factory, value): with pytest.raises(InvalidSchemaValue): unmarshaller_factory(spec)(value) + + @pytest.mark.parametrize( + "types,value", + [ + (["string", "null"], "string"), + (["number", "null"], 2), + (["number", "null"], 3.14), + (["boolean", "null"], True), + (["array", "null"], [1, 2]), + (["object", "null"], {}), + ], + ) + def test_nultiple_types(self, unmarshaller_factory, types, value): + schema = {"type": types} + spec = Spec.from_dict(schema) + + result = unmarshaller_factory(spec)(value) + + assert result == value + + @pytest.mark.parametrize( + "types,value", + [ + (["string", "null"], 2), + (["number", "null"], "string"), + (["number", "null"], True), + (["boolean", "null"], 3.14), + (["array", "null"], {}), + (["object", "null"], [1, 2]), + ], + ) + def test_nultiple_types_invalid(self, unmarshaller_factory, types, value): + schema = {"type": types} + spec = Spec.from_dict(schema) + + with pytest.raises(InvalidSchemaValue): + unmarshaller_factory(spec)(value)