diff --git a/mergin/client.py b/mergin/client.py index 6456c1b..9e32310 100644 --- a/mergin/client.py +++ b/mergin/client.py @@ -94,7 +94,7 @@ def __init__( proxy_config=None, ): self.url = url if url is not None else MerginClient.default_url() - self._auth_params = None + self._auth_params = {} self._auth_session = None self._user_info = None self._server_type = None @@ -192,36 +192,32 @@ def user_agent_info(self): system_version = platform.mac_ver()[0] return f"{self.client_version} ({platform.system()}/{system_version})" - def _check_token(f): - """Wrapper for creating/renewing authorization token.""" - - def wrapper(self, *args): - if self._auth_params: - if self._auth_session: - # Refresh auth token if it expired or will expire very soon - delta = self._auth_session["expire"] - datetime.now(timezone.utc) - if delta.total_seconds() < 5: - self.log.info("Token has expired - refreshing...") - if self._auth_params.get("login", None) and self._auth_params.get("password", None): - self.log.info("Token has expired - refreshing...") - self.login(self._auth_params["login"], self._auth_params["password"]) - else: - raise AuthTokenExpiredError("Token has expired - please re-login") - else: - # Create a new authorization token - self.log.info(f"No token - login user: {self._auth_params['login']}") - if self._auth_params.get("login", None) and self._auth_params.get("password", None): - self.login(self._auth_params["login"], self._auth_params["password"]) - else: - raise ClientError("Missing login or password") - - return f(self, *args) + def validate_auth(self): + """Validate that client has valid auth token or can be logged in.""" - return wrapper + if self._auth_session: + # Refresh auth token if it expired or will expire very soon + delta = self._auth_session["expire"] - datetime.now(timezone.utc) + if delta.total_seconds() < 5: + self.log.info("Token has expired - refreshing...") + if self._auth_params.get("login", None) and self._auth_params.get("password", None): + self.log.info("Token has expired - refreshing...") + self.login(self._auth_params["login"], self._auth_params["password"]) + else: + raise AuthTokenExpiredError("Token has expired - please re-login") + else: + # Create a new authorization token + self.log.info(f"No token - login user: {self._auth_params.get('login', None)}") + if self._auth_params.get("login", None) and self._auth_params.get("password", None): + self.login(self._auth_params["login"], self._auth_params["password"]) + else: + raise ClientError("Missing login or password") - @_check_token - def _do_request(self, request): + def _do_request(self, request, validate_auth=True): """General server request method.""" + if validate_auth: + self.validate_auth() + if self._auth_session: request.add_header("Authorization", self._auth_session["token"]) request.add_header("User-Agent", self.user_agent_info()) @@ -263,31 +259,31 @@ def _do_request(self, request): # e.g. when DNS resolution fails (no internet connection?) raise ClientError("Error requesting " + request.full_url + ": " + str(e)) - def get(self, path, data=None, headers={}): + def get(self, path, data=None, headers={}, validate_auth=True): url = urllib.parse.urljoin(self.url, urllib.parse.quote(path)) if data: url += "?" + urllib.parse.urlencode(data) request = urllib.request.Request(url, headers=headers) - return self._do_request(request) + return self._do_request(request, validate_auth=validate_auth) - def post(self, path, data=None, headers={}): + def post(self, path, data=None, headers={}, validate_auth=True): url = urllib.parse.urljoin(self.url, urllib.parse.quote(path)) if headers.get("Content-Type", None) == "application/json": data = json.dumps(data, cls=DateTimeEncoder).encode("utf-8") request = urllib.request.Request(url, data, headers, method="POST") - return self._do_request(request) + return self._do_request(request, validate_auth=validate_auth) - def patch(self, path, data=None, headers={}): + def patch(self, path, data=None, headers={}, validate_auth=True): url = urllib.parse.urljoin(self.url, urllib.parse.quote(path)) if headers.get("Content-Type", None) == "application/json": data = json.dumps(data, cls=DateTimeEncoder).encode("utf-8") request = urllib.request.Request(url, data, headers, method="PATCH") - return self._do_request(request) + return self._do_request(request, validate_auth=validate_auth) - def delete(self, path): + def delete(self, path, validate_auth=True): url = urllib.parse.urljoin(self.url, urllib.parse.quote(path)) request = urllib.request.Request(url, method="DELETE") - return self._do_request(request) + return self._do_request(request, validate_auth=validate_auth) def login(self, login, password): """ @@ -303,26 +299,16 @@ def login(self, login, password): self._auth_session = None self.log.info(f"Going to log in user {login}") try: - self._auth_params = params - url = urllib.parse.urljoin(self.url, urllib.parse.quote("/v1/auth/login")) - data = json.dumps(self._auth_params, cls=DateTimeEncoder).encode("utf-8") - request = urllib.request.Request(url, data, {"Content-Type": "application/json"}, method="POST") - request.add_header("User-Agent", self.user_agent_info()) - resp = self.opener.open(request) + resp = self.post( + "/v1/auth/login", data=params, headers={"Content-Type": "application/json"}, validate_auth=False + ) data = json.load(resp) session = data["session"] - except urllib.error.HTTPError as e: - if e.headers.get("Content-Type", "") == "application/problem+json": - info = json.load(e) - self.log.info(f"Login problem: {info.get('detail')}") - raise LoginError(info.get("detail")) - self.log.info(f"Login problem: {e.read().decode('utf-8')}") - raise LoginError(e.read().decode("utf-8")) - except urllib.error.URLError as e: - # e.g. when DNS resolution fails (no internet connection?) - raise ClientError("failure reason: " + str(e.reason)) + except ClientError as e: + self.log.info(f"Login problem: {e.detail}") + raise LoginError(e.detail) self._auth_session = { - "token": "Bearer %s" % session["token"], + "token": f"Bearer {session['token']}", "expire": dateutil.parser.parse(session["expire"]), } self._user_info = {"username": data["username"]} @@ -367,7 +353,7 @@ def server_type(self): """ if not self._server_type: try: - resp = self.get("/config") + resp = self.get("/config", validate_auth=False) config = json.load(resp) if config["server_type"] == "ce": self._server_type = ServerType.CE @@ -389,7 +375,7 @@ def server_version(self): """ if self._server_version is None: try: - resp = self.get("/config") + resp = self.get("/config", validate_auth=False) config = json.load(resp) self._server_version = config["version"] except (ClientError, KeyError): @@ -1386,7 +1372,7 @@ def remove_project_collaborator(self, project_id: str, user_id: int): def server_config(self) -> dict: """Get server configuration as dictionary.""" - response = self.get("/config") + response = self.get("/config", validate_auth=False) return json.load(response) def send_logs( diff --git a/mergin/test/test_client.py b/mergin/test/test_client.py index b5de8a6..214e010 100644 --- a/mergin/test/test_client.py +++ b/mergin/test/test_client.py @@ -5,7 +5,7 @@ import tempfile import subprocess import shutil -from datetime import datetime, timedelta, date +from datetime import datetime, timedelta, date, timezone import pytest import pytz import sqlite3 @@ -14,6 +14,7 @@ from .. import InvalidProject from ..client import ( MerginClient, + AuthTokenExpiredError, ClientError, MerginProject, LoginError, @@ -2888,8 +2889,7 @@ def test_mc_without_login(): with pytest.raises(ClientError) as e: mc.workspaces_list() - assert e.value.http_error == 401 - assert e.value.detail == '"Authentication information is missing or invalid."\n' + assert e.value.detail == "Missing login or password" def test_do_request_error_handling(mc: MerginClient): @@ -2911,3 +2911,61 @@ def test_do_request_error_handling(mc: MerginClient): assert e.value.http_error == 400 assert "Passwords must be at least 8 characters long." in e.value.detail + + +def test_validate_auth(mc: MerginClient): + """Test validate authentication under different scenarios.""" + + # ----- Client without authentication ----- + mc_not_auth = MerginClient(SERVER_URL) + + with pytest.raises(ClientError) as e: + mc_not_auth.validate_auth() + + assert e.value.detail == "Missing login or password" + + # ----- Client with token ----- + # create a client with valid auth token based on other MerginClient instance, but not with username/password + mc_auth_token = MerginClient(SERVER_URL, auth_token=mc._auth_session["token"]) + + # this should pass and not raise an error + mc_auth_token.validate_auth() + + # manually set expire date to the past to simulate expired token + mc_auth_token._auth_session["expire"] = datetime.now(timezone.utc) - timedelta(days=1) + + # check that this raises an error + with pytest.raises(AuthTokenExpiredError): + mc_auth_token.validate_auth() + + # ----- Client with token and username/password ----- + # create a client with valid auth token based on other MerginClient instance with username/password that allows relogin if the token is expired + mc_auth_token_login = MerginClient( + SERVER_URL, auth_token=mc._auth_session["token"], login=API_USER, password=USER_PWD + ) + + # this should pass and not raise an error + mc_auth_token_login.validate_auth() + + # manually set expire date to the past to simulate expired token + mc_auth_token_login._auth_session["expire"] = datetime.now(timezone.utc) - timedelta(days=1) + + # this should pass and not raise an error, as the client is able to re-login + mc_auth_token_login.validate_auth() + + # ----- Client with token and username/WRONG password ----- + # create a client with valid auth token based on other MerginClient instance with username and WRONG password + # that does NOT allow relogin if the token is expired + mc_auth_token_login_wrong_password = MerginClient( + SERVER_URL, auth_token=mc._auth_session["token"], login=API_USER, password="WRONG_PASSWORD" + ) + + # this should pass and not raise an error + mc_auth_token_login_wrong_password.validate_auth() + + # manually set expire date to the past to simulate expired token + mc_auth_token_login_wrong_password._auth_session["expire"] = datetime.now(timezone.utc) - timedelta(days=1) + + # this should pass and not raise an error, as the client is able to re-login + with pytest.raises(LoginError): + mc_auth_token_login_wrong_password.validate_auth()