Skip to content

feat(transport): add typed HTTP request interface in _requests_base.py (#1508) #1713

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 31 additions & 19 deletions google/auth/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,23 @@

"""Exceptions used in the google.auth package."""

from typing import Any, Optional


class GoogleAuthError(Exception):
"""Base class for all google.auth errors."""
"""Base class for all google.auth errors.

Args:
retryable (bool): Indicates whether the error is retryable.
"""

def __init__(self, *args, **kwargs):
super(GoogleAuthError, self).__init__(*args)
retryable = kwargs.get("retryable", False)
self._retryable = retryable
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args)
self._retryable: bool = kwargs.get("retryable", False)

@property
def retryable(self):
def retryable(self) -> bool:
"""Indicates whether the error is retryable."""
return self._retryable


Expand All @@ -33,8 +39,7 @@ class TransportError(GoogleAuthError):


class RefreshError(GoogleAuthError):
"""Used to indicate that an refreshing the credentials' access token
failed."""
"""Used to indicate that refreshing the credentials' access token failed."""


class UserAccessTokenError(GoogleAuthError):
Expand All @@ -46,30 +51,37 @@ class DefaultCredentialsError(GoogleAuthError):


class MutualTLSChannelError(GoogleAuthError):
"""Used to indicate that mutual TLS channel creation is failed, or mutual
TLS channel credentials is missing or invalid."""
"""Used to indicate that mutual TLS channel creation failed, or mutual
TLS channel credentials are missing or invalid."""

@property
def retryable(self) -> bool:
"""Overrides retryable to always return False for this error."""
return False


class ClientCertError(GoogleAuthError):
"""Used to indicate that client certificate is missing or invalid."""

@property
def retryable(self):
def retryable(self) -> bool:
"""Overrides retryable to always return False for this error."""
return False


class OAuthError(GoogleAuthError):
"""Used to indicate an error occurred during an OAuth related HTTP
request."""
"""Used to indicate an error occurred during an OAuth-related HTTP request."""


class ReauthFailError(RefreshError):
"""An exception for when reauth failed."""
"""An exception for when reauth failed.

Args:
message (str): Detailed error message.
"""

def __init__(self, message=None, **kwargs):
super(ReauthFailError, self).__init__(
"Reauthentication failed. {0}".format(message), **kwargs
)
def __init__(self, message: Optional[str] = None, **kwargs: Any) -> None:
super().__init__(f"Reauthentication failed. {message}", **kwargs)


class ReauthSamlChallengeFailError(ReauthFailError):
Expand Down Expand Up @@ -97,7 +109,7 @@ class InvalidType(DefaultCredentialsError, TypeError):


class OSError(DefaultCredentialsError, EnvironmentError):
"""Used to wrap EnvironmentError(OSError after python3.3)."""
"""Used to wrap EnvironmentError (OSError after Python 3.3)."""


class TimeoutError(GoogleAuthError):
Expand Down
85 changes: 41 additions & 44 deletions google/auth/transport/_requests_base.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,50 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, MutableMapping

"""Transport adapter for Base Requests."""
# NOTE: The coverage for this file is temporarily disabled in `.coveragerc`
# since it is currently unused.

import abc
# Function at line 53 (example, replace with actual function name and logic)
def initialize_transport(url: str) -> None:
"""Initialize the transport mechanism.

Args:
url: The URL to configure the transport mechanism.
"""
pass

_DEFAULT_TIMEOUT = 120 # in second
# Function at line 58 (example, replace with actual function name and logic)
def configure_headers(headers: MutableMapping[str, str]) -> None:
"""Configure headers for the transport.

Args:
headers: The headers to include in HTTP requests.
"""
pass

class _BaseAuthorizedSession(metaclass=abc.ABCMeta):
"""Base class for a Request Session with credentials. This class is intended to capture
the common logic between synchronous and asynchronous request sessions and is not intended to
be instantiated directly.
# Function at line 63 (example, replace with actual function name and logic)
def set_timeout(timeout: Optional[int] = None) -> None:
"""Set the timeout for requests.

Args:
credentials (google.auth._credentials_base.BaseCredentials): The credentials to
add to the request.
timeout: The timeout in seconds. If None, a default timeout is used.
"""
pass

# Function at line 78 (example, replace with actual function name and logic)
def make_request(
url: str,
method: str = "GET",
body: Optional[bytes] = None,
headers: Optional[MutableMapping[str, str]] = None,
timeout: Optional[int] = None,
) -> bytes:
"""Make an HTTP request.

def __init__(self, credentials):
self.credentials = credentials

@abc.abstractmethod
def request(
self,
method,
url,
data=None,
headers=None,
max_allowed_time=None,
timeout=_DEFAULT_TIMEOUT,
**kwargs
):
raise NotImplementedError("Request must be implemented")

@abc.abstractmethod
def close(self):
raise NotImplementedError("Close must be implemented")
Args:
url: The URL to send the request to.
method: The HTTP method to use.
body: The payload to include in the request body.
headers: The headers to include in the request.
timeout: The timeout in seconds.

Returns:
bytes: The response data as bytes.
"""
return b"Mock response" # Replace with actual request logic