Skip to content

Tolerate RecursionError not being defined in Python<3.5 #1624

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions elasticsearch/_async/http_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import urllib3 # type: ignore

from ..compat import urlencode
from ..compat import reraise_exceptions, urlencode
from ..connection.base import Connection
from ..exceptions import (
ConnectionError,
Expand Down Expand Up @@ -304,7 +304,7 @@ async def perform_request(
duration = self.loop.time() - start

# We want to reraise a cancellation or recursion error.
except (asyncio.CancelledError, RecursionError):
except reraise_exceptions:
raise
except Exception as e:
self.log_request_fail(
Expand Down
14 changes: 14 additions & 0 deletions elasticsearch/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,22 @@ def to_bytes(x, encoding="ascii"):
from collections import Mapping


try:
reraise_exceptions = (RecursionError,)
except NameError:
reraise_exceptions = ()

try:
import asyncio

reraise_exceptions += (asyncio.CancelledError,)
except (ImportError, AttributeError):
pass


__all__ = [
"string_types",
"reraise_exceptions",
"quote_plus",
"quote",
"urlencode",
Expand Down
3 changes: 2 additions & 1 deletion elasticsearch/compat.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
# under the License.

import sys
from typing import Callable, Tuple, Union
from typing import Callable, Tuple, Type, Union

PY2: bool
string_types: Tuple[type, ...]

to_str: Callable[[Union[str, bytes]], str]
to_bytes: Callable[[Union[str, bytes]], bytes]
reraise_exceptions: Tuple[Type[Exception], ...]

if sys.version_info[0] == 2:
from itertools import imap as map
Expand Down
4 changes: 2 additions & 2 deletions elasticsearch/connection/http_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import time
import warnings

from ..compat import string_types, urlencode
from ..compat import reraise_exceptions, string_types, urlencode
from ..exceptions import (
ConnectionError,
ConnectionTimeout,
Expand Down Expand Up @@ -166,7 +166,7 @@ def perform_request(
response = self.session.send(prepared_request, **send_kwargs)
duration = time.time() - start
raw_data = response.content.decode("utf-8", "surrogatepass")
except RecursionError:
except reraise_exceptions:
raise
except Exception as e:
self.log_request_fail(
Expand Down
4 changes: 2 additions & 2 deletions elasticsearch/connection/http_urllib3.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from urllib3.exceptions import SSLError as UrllibSSLError # type: ignore
from urllib3.util.retry import Retry # type: ignore

from ..compat import urlencode
from ..compat import reraise_exceptions, urlencode
from ..exceptions import (
ConnectionError,
ConnectionTimeout,
Expand Down Expand Up @@ -253,7 +253,7 @@ def perform_request(
)
duration = time.time() - start
raw_data = response.data.decode("utf-8", "surrogatepass")
except RecursionError:
except reraise_exceptions:
raise
except Exception as e:
self.log_request_fail(
Expand Down
21 changes: 21 additions & 0 deletions test_elasticsearch/test_async/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from multidict import CIMultiDict

from elasticsearch import AIOHttpConnection, __versionstr__
from elasticsearch.compat import reraise_exceptions
from elasticsearch.exceptions import ConnectionError

pytestmark = pytest.mark.asyncio

Expand Down Expand Up @@ -318,6 +320,20 @@ async def test_surrogatepass_into_bytes(self):
status, headers, data = await con.perform_request("GET", "/")
assert u"你好\uda6a" == data

@pytest.mark.parametrize("exception_cls", reraise_exceptions)
async def test_recursion_error_reraised(self, exception_cls):
conn = AIOHttpConnection()

def request_raise(*_, **__):
raise exception_cls("Wasn't modified!")

await conn._create_aiohttp_session()
conn.session.request = request_raise

with pytest.raises(exception_cls) as e:
await conn.perform_request("GET", "/")
assert str(e.value) == "Wasn't modified!"


class TestConnectionHttpbin:
"""Tests the HTTP connection implementations against a live server E2E"""
Expand Down Expand Up @@ -389,3 +405,8 @@ async def test_aiohttp_connection(self):
"Header2": "value2",
"User-Agent": user_agent,
}

async def test_aiohttp_connection_error(self):
conn = AIOHttpConnection("not.a.host.name")
with pytest.raises(ConnectionError):
await conn.perform_request("GET", "/")
42 changes: 42 additions & 0 deletions test_elasticsearch/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@
from urllib3._collections import HTTPHeaderDict

from elasticsearch import __versionstr__
from elasticsearch.compat import reraise_exceptions
from elasticsearch.connection import (
Connection,
RequestsHttpConnection,
Urllib3HttpConnection,
)
from elasticsearch.exceptions import (
ConflictError,
ConnectionError,
NotFoundError,
RequestError,
TransportError,
Expand Down Expand Up @@ -466,6 +468,21 @@ def test_surrogatepass_into_bytes(self):
status, headers, data = con.perform_request("GET", "/")
self.assertEqual(u"你好\uda6a", data)

@pytest.mark.skipif(
not reraise_exceptions, reason="RecursionError isn't defined in Python <3.5"
)
def test_recursion_error_reraised(self):
conn = Urllib3HttpConnection()

def urlopen_raise(*_, **__):
raise RecursionError("Wasn't modified!")

conn.pool.urlopen = urlopen_raise

with pytest.raises(RecursionError) as e:
conn.perform_request("GET", "/")
assert str(e.value) == "Wasn't modified!"


class TestRequestsConnection(TestCase):
def _get_mock_connection(
Expand Down Expand Up @@ -868,6 +885,21 @@ def test_surrogatepass_into_bytes(self):
status, headers, data = con.perform_request("GET", "/")
self.assertEqual(u"你好\uda6a", data)

@pytest.mark.skipif(
not reraise_exceptions, reason="RecursionError isn't defined in Python <3.5"
)
def test_recursion_error_reraised(self):
conn = RequestsHttpConnection()

def send_raise(*_, **__):
raise RecursionError("Wasn't modified!")

conn.session.send = send_raise

with pytest.raises(RecursionError) as e:
conn.perform_request("GET", "/")
assert str(e.value) == "Wasn't modified!"


class TestConnectionHttpbin:
"""Tests the HTTP connection implementations against a live server E2E"""
Expand Down Expand Up @@ -942,6 +974,11 @@ def test_urllib3_connection(self):
"User-Agent": user_agent,
}

def test_urllib3_connection_error(self):
conn = Urllib3HttpConnection("not.a.host.name")
with pytest.raises(ConnectionError):
conn.perform_request("GET", "/")

def test_requests_connection(self):
# Defaults
conn = RequestsHttpConnection("httpbin.org", port=443, use_ssl=True)
Expand Down Expand Up @@ -1003,3 +1040,8 @@ def test_requests_connection(self):
"Header2": "value2",
"User-Agent": user_agent,
}

def test_requests_connection_error(self):
conn = RequestsHttpConnection("not.a.host.name")
with pytest.raises(ConnectionError):
conn.perform_request("GET", "/")