Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 39 additions & 16 deletions connexion/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import contextlib
import io
import json
import os
import typing as t
import urllib.parse
Expand All @@ -13,9 +14,11 @@

import requests
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the imports are grouped (1) python standard, (2) third party and (3) connexion. FWIW maybe it's worth sorting within categories (1) and (2) here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a sorting within categories: At first normal imports sorted alphabetically, then from imports sorted alphabetically.

If I change the order isort will complain.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for explaining, I didn't spot that pattern, never mind!

import yaml
from jsonschema import Draft4Validator, RefResolver
from jsonschema.exceptions import RefResolutionError, ValidationError # noqa
from jsonschema import Draft4Validator
from jsonschema.exceptions import ValidationError
from jsonschema.validators import extend
from referencing import Registry, Resource
from referencing.jsonschema import DRAFT4

from .utils import deep_get

Expand Down Expand Up @@ -62,12 +65,27 @@ def __call__(self, uri):
return yaml.load(fh, ExtendedSafeLoader)


handlers = {
"http": URLHandler(),
"https": URLHandler(),
"file": FileHandler(),
"": FileHandler(),
}
def resource_from_spec(spec: t.Dict[str, t.Any]) -> Resource:
"""Create a `referencing.Resource` from a schema specification."""
return Resource.from_contents(spec, default_specification=DRAFT4)


def retrieve(uri: str) -> Resource:
"""Retrieve a resource from a URI.

This function is passed to the `referencing.Registry`,
which calls it any URI is not present in the registry is accessed."""
parsed = urllib.parse.urlsplit(uri)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you please consider adding pydoc for this new function?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

if parsed.scheme in ("http", "https"):
content = URLHandler()(uri)
elif parsed.scheme in ("file", ""):
content = FileHandler()(uri)
else: # pragma: no cover
# Default branch from jsonschema.RefResolver.resolve_remote()
# for backwards compatibility.
with urllib.request.urlopen(uri) as url:
content = json.loads(url.read().decode("utf-8"))
return resource_from_spec(content)


def resolve_refs(spec, store=None, base_uri=""):
Expand All @@ -78,32 +96,37 @@ def resolve_refs(spec, store=None, base_uri=""):
"""
spec = deepcopy(spec)
store = store or {}
resolver = RefResolver(base_uri, spec, store, handlers=handlers)
registry = Registry(retrieve=retrieve).with_resources(
(
(base_uri, resource_from_spec(spec)),
*((key, resource_from_spec(value)) for key, value in store.items()),
)
)

def _do_resolve(node):
def _do_resolve(node, resolver):
if isinstance(node, Mapping) and "$ref" in node:
path = node["$ref"][2:].split("/")
try:
# resolve known references
retrieved = deep_get(spec, path)
node.update(retrieved)
if isinstance(retrieved, Mapping) and "$ref" in retrieved:
node = _do_resolve(node)
node = _do_resolve(node, resolver)
node.pop("$ref", None)
return node
except KeyError:
# resolve external references
with resolver.resolving(node["$ref"]) as resolved:
return _do_resolve(resolved)
resolved = resolver.lookup(node["$ref"])
return _do_resolve(resolved.contents, resolved.resolver)
elif isinstance(node, Mapping):
for k, v in node.items():
node[k] = _do_resolve(v)
node[k] = _do_resolve(v, resolver)
elif isinstance(node, (list, tuple)):
for i, _ in enumerate(node):
node[i] = _do_resolve(node[i])
node[i] = _do_resolve(node[i], resolver)
return node

res = _do_resolve(spec)
res = _do_resolve(spec, registry.resolver(base_uri))
return res


Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Jinja2 = ">= 3.0.0"
python-multipart = ">= 0.0.15"
PyYAML = ">= 5.1"
requests = ">= 2.27"
referencing = ">= 0.12.0"
starlette = ">= 0.35"
typing-extensions = ">= 4.6.1"
werkzeug = ">= 2.2.1"
Expand Down
9 changes: 5 additions & 4 deletions tests/test_references.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from unittest import mock

import pytest
from connexion.json_schema import RefResolutionError, resolve_refs
from connexion.json_schema import resolve_refs
from connexion.jsonifier import Jsonifier
from referencing.exceptions import Unresolvable

DEFINITIONS = {
"new_stack": {
Expand Down Expand Up @@ -50,7 +51,7 @@ def test_non_existent_reference(api):
}
]
}
with pytest.raises(RefResolutionError) as exc_info: # type: py.code.ExceptionInfo
with pytest.raises(Unresolvable) as exc_info: # type: py.code.ExceptionInfo
resolve_refs(op_spec, {})

exception = exc_info.value
Expand All @@ -69,7 +70,7 @@ def test_invalid_reference(api):
]
}

with pytest.raises(RefResolutionError) as exc_info: # type: py.code.ExceptionInfo
with pytest.raises(Unresolvable) as exc_info: # type: py.code.ExceptionInfo
resolve_refs(
op_spec, {"definitions": DEFINITIONS, "parameters": PARAMETER_DEFINITIONS}
)
Expand All @@ -84,7 +85,7 @@ def test_resolve_invalid_reference(api):
"parameters": [{"$ref": "/parameters/fail"}],
}

with pytest.raises(RefResolutionError) as exc_info:
with pytest.raises(Unresolvable) as exc_info:
resolve_refs(op_spec, {"parameters": PARAMETER_DEFINITIONS})

exception = exc_info.value
Expand Down