Skip to content

feature: artifact-helper #750

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions openeo/extra/artifacts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from openeo.extra.artifacts.artifact_helper import ArtifactHelper
Empty file.
96 changes: 96 additions & 0 deletions openeo/extra/artifacts/_s3sts/artifact_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from __future__ import annotations

import datetime
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from types_boto3_s3.client import S3Client

from pathlib import Path

from boto3.s3.transfer import TransferConfig

from openeo.extra.artifacts._s3sts.config import S3STSConfig
from openeo.extra.artifacts._s3sts.model import S3URI, AWSSTSCredentials
from openeo.extra.artifacts._s3sts.sts import OpenEOSTSClient
from openeo.extra.artifacts.artifact_helper_abc import ArtifactHelperABC
from openeo.rest.connection import Connection


class S3STSArtifactHelper(ArtifactHelperABC):
    """
    Artifact helper that uploads files to a backend understanding the S3 API.

    The S3 client is built with temporary credentials that are obtained through
    :py:class:`OpenEOSTSClient` from the given openEO connection.
    """

    # From what size will we switch to multi-part-upload
    MULTIPART_THRESHOLD_IN_MB = 50

    def __init__(self, conn: Connection, config: S3STSConfig):
        """
        :param conn: The openEO connection used to obtain temporary credentials.
        :param config: The storage config used to build the STS and S3 clients.
        """
        super().__init__(config)
        self.conn = conn
        self.config = config
        # Credentials are fetched eagerly; the S3 client is only built on first use.
        self._creds = self.get_new_creds()
        self._s3: Optional[S3Client] = None

    @classmethod
    def _from_openeo_connection(cls, conn: Connection, config: S3STSConfig) -> S3STSArtifactHelper:
        """Build a helper for the given connection and config."""
        # Go through `cls` so a subclass gets an instance of its own type.
        return cls(conn, config=config)

    def get_new_creds(self) -> AWSSTSCredentials:
        """Request fresh temporary credentials for the connection of this helper."""
        sts = OpenEOSTSClient(config=self.config)
        return sts.assume_from_openeo_connection(self.conn)

    def _user_prefix(self) -> str:
        """Each user has its own prefix, retrieve it."""
        return self._creds.get_user_hash()

    def _get_upload_prefix(self) -> str:
        """Return the key prefix for uploads: ``<user prefix>/<YYYY>/<MM>/<DD>/`` (UTC date)."""
        # `utcnow()` is deprecated since Python 3.12; an aware "now" in UTC yields the same date parts.
        now_utc = datetime.datetime.now(datetime.timezone.utc)
        return f"{self._user_prefix()}/{now_utc.strftime('%Y/%m/%d')}/"

    def _get_upload_key(self, object_name: str) -> str:
        """Return the full object key for an object name."""
        return f"{self._get_upload_prefix()}{object_name}"

    @staticmethod
    def get_object_name_from_path(path: str | Path) -> str:
        """Return the final component (file name) of a path."""
        return Path(path).name

    def _get_s3_client(self):
        """Return the S3 client, building it on first use."""
        # TODO: validate whether credentials are still reasonably long valid
        #  and if not refresh credentials and rebuild client
        if self._s3 is None:
            self._s3 = self.config.build_client("s3", session_kwargs=self._creds.as_kwargs())
        return self._s3

    def upload_file(self, path: str | Path, object_name: str = "") -> S3URI:
        """
        Upload a file to a backend understanding the S3 API

        :param path: A file path to the file that must be uploaded
        :param object_name: Optional the final part of the name to be uploaded. If omitted the filename is used.

        :return: `S3URI` A S3URI that points to the uploaded file in the S3 compatible backend
        """
        mb = 1024**2
        transfer_config = TransferConfig(multipart_threshold=self.MULTIPART_THRESHOLD_IN_MB * mb)
        bucket = self.config.bucket
        key = self._get_upload_key(object_name or self.get_object_name_from_path(path))
        self._get_s3_client().upload_file(str(path), bucket, key, Config=transfer_config)
        return S3URI(bucket, key)

    def get_presigned_url(self, storage_uri: S3URI, expires_in_seconds: int = 7 * 3600 * 24) -> str:
        """
        Get a presigned URL to allow retrieval of an object.

        :param storage_uri: `S3URI` A S3URI that points to the uploaded file in the S3 compatible backend
        :param expires_in_seconds: Optional the number of seconds the link is valid for (defaults to 7 days)

        :return: `str` A HTTP url that can be used to download a file. It also supports Range header in its requests.
        """
        url = self._get_s3_client().generate_presigned_url(
            "get_object",
            Params={"Bucket": storage_uri.bucket, "Key": storage_uri.key},
            ExpiresIn=expires_in_seconds,
        )
        # `_config` is not assigned in this class (presumably by the base class); narrow its type here.
        assert isinstance(self._config, S3STSConfig)
        return self._config.add_trace_id_qp_if_needed(url)

    @classmethod
    def _get_default_storage_config(cls) -> S3STSConfig:
        """Return a config with all defaults."""
        return S3STSConfig()
121 changes: 121 additions & 0 deletions openeo/extra/artifacts/_s3sts/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from dataclasses import dataclass, field
from typing import Optional

import boto3
import botocore
from botocore.config import Config

from openeo.extra.artifacts._s3sts.tracer import (
add_trace_id,
add_trace_id_as_query_parameter,
)
from openeo.extra.artifacts.backend import ProviderConfig
from openeo.extra.artifacts.config import StorageConfig
from openeo.utils.version import ComparableVersion

# Before botocore 1.36 checksumming was not done by default and there was no
# opt-out, so the opt-out setting is only passed from 1.36.0 on.
no_default_checksum_config = (
    Config()
    if ComparableVersion(botocore.__version__).below("1.36.0")
    else Config(request_checksum_calculation="when_required")
)


# Trace id value that means "tracing is disabled".
DISABLE_TRACING_TRACE_ID = "00000000-0000-0000-0000-000000000000"


@dataclass(frozen=True)
class S3STSConfig(StorageConfig):
    """
    Settings needed to build boto3 clients for the S3 and STS endpoints.

    The fields ``s3_endpoint``, ``sts_endpoint``, ``role`` and ``bucket`` that are
    left at ``None`` are filled in by ``_load_connection_provided_config``.
    """

    # The s3 endpoint url protocol://fqdn[:portnumber]
    s3_endpoint: Optional[str] = None
    # The sts endpoint url protocol://fqdn[:portnumber]
    sts_endpoint: Optional[str] = None
    # The trace_id is if you want to send a uuid4 identifier to the backend
    trace_id: str = DISABLE_TRACING_TRACE_ID
    # You can change the botocore_config used but this is an expert option
    botocore_config: Config = field(default_factory=lambda: no_default_checksum_config)
    # The role ARN to be assumed
    role: Optional[str] = None
    # The bucket to store the object into
    bucket: Optional[str] = None

    def _load_connection_provided_config(self, provider_config: ProviderConfig) -> None:
        """
        Fill in every unset (``None``) connection related field from the provider config.

        :param provider_config: Mapping that is read by the keys
            ``s3_endpoint``, ``sts_endpoint``, ``role`` and ``bucket``.
        """
        for name in ("s3_endpoint", "sts_endpoint", "role", "bucket"):
            if getattr(self, name) is None:
                # The dataclass is frozen, so regular attribute assignment is blocked.
                object.__setattr__(self, name, provider_config[name])

    def should_trace(self) -> bool:
        """Return whether a trace id other than the "disabled" sentinel is configured."""
        return self.trace_id != DISABLE_TRACING_TRACE_ID

    def build_client(self, service_name: str, session_kwargs: Optional[dict] = None):
        """
        Build a boto3 client for an OpenEO service provider.

        :param service_name: The service you want to consume: s3|sts
        :param session_kwargs: A dictionary with keyword arguments that will be passed
            when creating the boto session

        :return: The boto3 client, with the trace id hook registered when tracing is enabled.
        """
        session_kwargs = session_kwargs or {}
        session = boto3.Session(region_name=self._get_storage_region(), **session_kwargs)
        client = session.client(
            service_name,
            endpoint_url=self._get_endpoint_url(service_name),
            config=self.botocore_config,
        )
        if self.should_trace():
            add_trace_id(client, self.trace_id)
        return client

    @staticmethod
    def _remove_protocol_from_uri(uri: str):
        """
        Return the part of the uri after the first ``://``.

        :raises ValueError: If the uri contains no ``://``.
        """
        _, separator, remainder = uri.partition("://")
        if not separator:
            raise ValueError("_remove_protocol_from_uri must be of form protocol://...")
        return remainder

    def _get_storage_region(self) -> str:
        """
        Derive the region from the S3 endpoint.

        S3 URIs follow the convention detailed on https://docs.aws.amazon.com/general/latest/gr/s3.html
        The region is the first dot-separated part after ``s3`` (or ``s3-fips``)
        that is not a reserved word.

        :raises ValueError: If no region can be determined from the endpoint.
        """
        s3_names = ("s3", "s3-fips")
        reserved_words = {"dualstack", "prod", "stag", "dev"}
        if not self.s3_endpoint:
            raise ValueError(f"Cannot determine region from {self.s3_endpoint}")
        endpoint_parts = self._remove_protocol_from_uri(self.s3_endpoint).split(".")
        for s3_name in s3_names:
            if s3_name not in endpoint_parts:
                continue
            idx = endpoint_parts.index(s3_name) + 1
            # Compare whole parts: a substring test would also skip regions such as "de" or "st".
            while idx < len(endpoint_parts) and endpoint_parts[idx] in reserved_words:
                idx += 1
            # Running out of parts (or an empty part) means there is no region behind this name.
            if idx < len(endpoint_parts) and endpoint_parts[idx]:
                return endpoint_parts[idx]
        raise ValueError(f"Cannot determine region from {self.s3_endpoint}")

    def _get_endpoint_url(self, service_name: str) -> str:
        """
        Return the configured endpoint url of a service.

        :raises ValueError: If the service is neither ``s3`` nor ``sts``.
        """
        if service_name == "s3":
            return self.s3_endpoint
        elif service_name == "sts":
            return self.sts_endpoint
        raise ValueError(f"Unsupported service {service_name}")

    def add_trace_id_qp_if_needed(self, url: str) -> str:
        """Return the url with the trace id added as query parameter, or unchanged when not tracing."""
        if not self.should_trace():
            return url
        return add_trace_id_as_query_parameter(url, self.trace_id)

    def get_sts_role_arn(self) -> str:
        """Return the configured role."""
        return self.role
66 changes: 66 additions & 0 deletions openeo/extra/artifacts/_s3sts/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from __future__ import annotations

import datetime
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from types_boto3_sts.type_defs import AssumeRoleWithWebIdentityResponseTypeDef

import hashlib
from dataclasses import dataclass
from urllib.parse import urlparse

from openeo.extra.artifacts.uri import StorageURI


@dataclass(frozen=True)
class AWSSTSCredentials:
    """Temporary credentials taken from an assume-role-with-web-identity response."""

    aws_access_key_id: str
    aws_secret_access_key: str
    aws_session_token: str
    subject_from_web_identity_token: str
    expiration: datetime.datetime

    @classmethod
    def from_assume_role_response(cls, resp: AssumeRoleWithWebIdentityResponseTypeDef) -> AWSSTSCredentials:
        """
        Build credentials from an assume-role-with-web-identity response.

        :param resp: Response mapping with a ``Credentials`` entry and a
            ``SubjectFromWebIdentityToken`` entry.
        :return: The credentials; an instance of the class this is called on.
        """
        credentials = resp["Credentials"]
        # Use `cls` so that calling this on a subclass returns that subclass.
        return cls(
            aws_access_key_id=credentials["AccessKeyId"],
            aws_secret_access_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
            subject_from_web_identity_token=resp["SubjectFromWebIdentityToken"],
            expiration=credentials["Expiration"],
        )

    def as_kwargs(self) -> dict:
        """Return the key id, secret key and session token as a keyword argument dict."""
        return {
            "aws_access_key_id": self.aws_access_key_id,
            "aws_secret_access_key": self.aws_secret_access_key,
            "aws_session_token": self.aws_session_token,
        }

    def get_user_hash(self) -> str:
        """Return the hexadecimal SHA-1 digest of the web identity subject."""
        return hashlib.sha1(self.subject_from_web_identity_token.encode()).hexdigest()


@dataclass(frozen=True)
class S3URI(StorageURI):
    """Location of an object in S3 compatible storage: a bucket and a key."""

    bucket: str
    key: str

    @classmethod
    def from_str(cls, uri: str) -> S3URI:
        """
        Parse a string of form ``s3://<bucket>/<key>``.

        :param uri: The URI to parse.
        :return: The parsed URI; an instance of the class this is called on.
        :raises ValueError: If the scheme of the URI is not ``s3``.
        """
        parsed = urlparse(uri, allow_fragments=False)
        if parsed.scheme != "s3":
            raise ValueError(f"Input {uri} is not a valid S3 URI should be of form s3://<bucket>/<key>")
        key = parsed.path.lstrip("/")
        if parsed.query:
            # A "?" is part of the key here, so glue the split-off query back on.
            key = f"{key}?{parsed.query}"
        # Use `cls` so that calling this on a subclass returns that subclass.
        return cls(parsed.netloc, key)

    def to_string(self) -> str:
        """Return the URI in ``s3://<bucket>/<key>`` form."""
        return f"s3://{self.bucket}/{self.key}"
43 changes: 43 additions & 0 deletions openeo/extra/artifacts/_s3sts/sts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from types_boto3_sts.client import STSClient

from openeo.extra.artifacts._s3sts.config import S3STSConfig
from openeo.extra.artifacts._s3sts.model import AWSSTSCredentials
from openeo.extra.artifacts.exceptions import ProviderSpecificException
from openeo.rest.auth.auth import BearerAuth
from openeo.rest.connection import Connection
from openeo.util import Rfc3339


class OpenEOSTSClient:
    """Client that exchanges the bearer token of an openEO connection for temporary credentials."""

    # Requested lifetime of the temporary credentials (12 hours).
    SESSION_DURATION_SECONDS = 43200

    def __init__(self, config: S3STSConfig):
        """
        :param config: The config used to build the STS client and to look up the role.
        """
        self.config = config

    def assume_from_openeo_connection(self, conn: Connection) -> AWSSTSCredentials:
        """
        Takes an OpenEO connection object and returns temporary credentials to interact with S3

        :param conn: The connection; its auth must be a ``BearerAuth``.
        :return: The temporary credentials from the assume-role-with-web-identity call.
        :raises ProviderSpecificException: If the connection has no ``BearerAuth`` or no
            token can be extracted from its bearer.
        """
        auth = conn.auth
        # This also rejects a missing auth (None), so no separate assert is needed.
        if not isinstance(auth, BearerAuth):
            raise ProviderSpecificException("Only connections that have BearerAuth can be used.")
        # The token is the third "/"-separated part of the bearer. Split at most
        # twice so a token that itself contains "/" is kept in one piece.
        bearer_parts = auth.bearer.split("/", 2)
        if len(bearer_parts) != 3:
            raise ProviderSpecificException(
                "Cannot extract a token from the bearer of the connection: expected 3 '/'-separated parts."
            )
        web_identity_token = bearer_parts[2]

        # NOTE(review): AWS STS restricts RoleSessionName to the characters [\w+=,.@-];
        # if the timestamp contains ":" confirm the targeted STS implementation accepts it.
        response = self._get_sts_client().assume_role_with_web_identity(
            RoleArn=self._get_aws_access_role(),
            RoleSessionName=f"artifact-helper-{Rfc3339().now_utc()}",
            WebIdentityToken=web_identity_token,
            DurationSeconds=self.SESSION_DURATION_SECONDS,
        )
        return AWSSTSCredentials.from_assume_role_response(response)

    def _get_sts_client(self) -> STSClient:
        """Build an STS client from the config."""
        return self.config.build_client("sts")

    def _get_aws_access_role(self) -> str:
        """Return the role ARN to assume, as provided by the config."""
        return self.config.get_sts_role_arn()
31 changes: 31 additions & 0 deletions openeo/extra/artifacts/_s3sts/tracer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import logging
from typing import Callable

"""
The trace helps to pass on an X-Request-ID header value in all requests made by a
boto3 call.
"""

# Name of the header (and query parameter) that carries the trace id.
TRACE_ID_KEY = "X-Request-ID"


def create_header_adder(request_id: str) -> Callable:
    """
    Create an event handler that adds the trace id header to a request.

    :param request_id: The value to send in the ``X-Request-ID`` header.
    :return: A handler taking the request and the event keyword arguments. It skips
        requests whose ``signature_version`` keyword argument contains ``query``.
    """
    logger = logging.getLogger("openeo.extra.artifacts")

    def add_request_id_header(request, **kwargs) -> None:
        signature_version = kwargs.get("signature_version", "unknown")
        if "query" in signature_version:
            logger.debug("Do not add trace header for requests using query parameters instead of headers")
            return
        # Lazy %-formatting so the actual id ends up in the log message.
        logger.debug("Adding trace id: %s", request_id)
        request.headers.add_header(TRACE_ID_KEY, request_id)

    return add_request_id_header


def add_trace_id(client, trace_id: str = "") -> None:
    """
    Register a handler on the client that adds the trace id header.

    :param client: Client whose ``meta.events`` is used to register the handler
        for the ``before-sign.s3`` event.
    :param trace_id: The trace id to send along.
    """
    client.meta.events.register("before-sign.s3", create_header_adder(trace_id))


def add_trace_id_as_query_parameter(url, trace_id: str) -> str:
    """
    Append the trace id to a url as query parameter.

    :param url: The url to extend.
    :param trace_id: The trace id to add; it is appended as is, without url-encoding.
    :return: The url with the trace id parameter appended.
    """
    # Start the query string when the url does not have one yet, otherwise extend it.
    separator = "&" if "?" in url else "?"
    return f"{url}{separator}{TRACE_ID_KEY}={trace_id}"
Loading