Skip to content

Commit e57aa6b

Browse files
committed
RestApiConnection.request: add option for expected status code
1 parent df72dfd commit e57aa6b

File tree

4 files changed

+61
-6
lines changed

4 files changed

+61
-6
lines changed

openeo/rest/connection.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import requests
1313
from deprecated import deprecated
1414
from openeo.rest import OpenEoClientException
15+
from openeo.util import ensure_list
1516
from requests import Response
1617
from requests.auth import HTTPBasicAuth, AuthBase
1718

@@ -67,7 +68,8 @@ def _merged_headers(self, headers: dict) -> dict:
6768
result.update(headers)
6869
return result
6970

70-
def request(self, method: str, path: str, headers: dict = None, auth: AuthBase = None, check_status=True, **kwargs):
71+
def request(self, method: str, path: str, headers: dict = None, auth: AuthBase = None,
72+
check_error=True, expected_status=None, **kwargs):
7173
"""Generic request send"""
7274
resp = self.session.request(
7375
method=method,
@@ -76,10 +78,12 @@ def request(self, method: str, path: str, headers: dict = None, auth: AuthBase =
7678
auth=auth or self.auth,
7779
**kwargs
7880
)
79-
if check_status:
80-
# TODO: option to specify the list/range of expected status codes?
81-
if resp.status_code >= 400:
82-
self._raise_api_error(resp)
81+
# Check for API errors and unexpected HTTP status codes as desired.
82+
status = resp.status_code
83+
if check_error and status >= 400:
84+
self._raise_api_error(resp)
85+
if expected_status and status not in ensure_list(expected_status):
86+
raise OpenEoClientException("Status code {s} is not expected {e}".format(s=status, e=expected_status))
8387
return resp
8488

8589
def _raise_api_error(self, response: requests.Response):

openeo/util.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ def first_not_none(*args):
3434
raise ValueError("No not-None values given.")
3535

3636

37+
def ensure_list(x):
38+
"""Convert given data structure to a list."""
39+
try:
40+
return list(x)
41+
except TypeError:
42+
return [x]
43+
44+
3745
def get_temporal_extent(*args,
3846
start_date: Union[str, datetime, date] = None, end_date: Union[str, datetime, date] = None,
3947
extent: Union[list, tuple] = None,

tests/rest/test_connection.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44
import requests_mock
5+
from openeo.rest import OpenEoClientException
56

67
from openeo.rest.auth.auth import NullAuth, BearerAuth
78
from openeo.rest.connection import Connection, RestApiConnection, connect, OpenEoApiError
@@ -48,6 +49,35 @@ def text(request, context):
4849
conn.post("/foo", {}, headers={"X-Openeo-Bar": "XY123"})
4950

5051

52+
def test_rest_api_expected_status(requests_mock):
53+
conn = RestApiConnection(API_URL)
54+
requests_mock.get("https://oeo.net/foo", status_code=200, json={"o": "k"})
55+
# Expected status
56+
assert conn.get("/foo", expected_status=200).json() == {"o": "k"}
57+
assert conn.get("/foo", expected_status=[200, 201]).json() == {"o": "k"}
58+
# Unexpected status
59+
with pytest.raises(OpenEoClientException, match="Status code 200 is not expected 204"):
60+
conn.get("/foo", expected_status=204)
61+
with pytest.raises(OpenEoClientException, match=r"Status code 200 is not expected \[203, 204\]"):
62+
conn.get("/foo", expected_status=[203, 204])
63+
64+
65+
def test_rest_api_expected_status_with_error(requests_mock):
66+
conn = RestApiConnection(API_URL)
67+
requests_mock.get("https://oeo.net/bar", status_code=406, json={"code": "NoBar", "message": "no bar please"})
68+
# First check for API error by default
69+
with pytest.raises(OpenEoApiError, match=r"\[406\] NoBar: no bar please"):
70+
conn.get("/bar", expected_status=200)
71+
with pytest.raises(OpenEoApiError, match=r"\[406\] NoBar: no bar please"):
72+
conn.get("/bar", expected_status=[201, 202])
73+
# Don't check for error, just status
74+
conn.get("/bar", check_error=False, expected_status=406)
75+
with pytest.raises(OpenEoClientException, match="Status code 406 is not expected 302"):
76+
conn.get("/bar", check_error=False, expected_status=302)
77+
with pytest.raises(OpenEoClientException, match=r"Status code 406 is not expected \[302, 303\]"):
78+
conn.get("/bar", check_error=False, expected_status=[302, 303])
79+
80+
5181
def test_connection_with_session():
5282
session = mock.Mock()
5383
response = session.request.return_value

tests/test_util.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import pytest
66

7-
from openeo.util import first_not_none, get_temporal_extent, TimingLogger
7+
from openeo.util import first_not_none, get_temporal_extent, TimingLogger, ensure_list
88

99

1010
@pytest.mark.parametrize(['input', 'expected'], [
@@ -30,6 +30,19 @@ def test_first_not_none_failures():
3030
first_not_none(None, None)
3131

3232

33+
@pytest.mark.parametrize(["input", "expected"], [
34+
(None, [None]),
35+
(123, [123]),
36+
("abc", ["a", "b", "c"]),
37+
([1, 2, "three"], [1, 2, "three"]),
38+
((1, 2, "three"), [1, 2, "three"]),
39+
({1: "a", 2: "b"}, [1, 2]),
40+
({1, 2, 3, 3, 2}, [1, 2, 3]),
41+
])
42+
def test_ensure_list(input, expected):
43+
assert ensure_list(input) == expected
44+
45+
3346
def test_get_temporal_extent():
3447
assert get_temporal_extent("2019-03-15") == ("2019-03-15", None)
3548
assert get_temporal_extent("2019-03-15", "2019-10-11") == ("2019-03-15", "2019-10-11")

0 commit comments

Comments
 (0)