diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..852c375b --- /dev/null +++ b/.flake8 @@ -0,0 +1,20 @@ +# This file is part of the EESSI filesystem layer, +# see https://github.com/EESSI/filesystem-layer +# +# author: Thomas Roeblitz (@trz42) +# +# license: GPLv2 +# + +[flake8] +exclude = + scripts/check-stratum-servers.py, + scripts/automated_ingestion/automated_ingestion.py, + scripts/automated_ingestion/eessitarball.py, + scripts/automated_ingestion/utils.py + +# ignore "Black would make changes" produced by flake8-black +# see also https://github.com/houndci/hound/issues/1769 +extend-ignore = BLK100 + +max-line-length = 120 diff --git a/.github/workflows/test-ingest-python-code.yml b/.github/workflows/test-ingest-python-code.yml new file mode 100644 index 00000000..9e341783 --- /dev/null +++ b/.github/workflows/test-ingest-python-code.yml @@ -0,0 +1,49 @@ +# This file is part of the EESSI filesystem layer, +# see https://github.com/EESSI/filesystem-layer +# +# author: Thomas Roeblitz (@trz42) +# +# license: GPLv2 +# + +name: Run tests +on: [push, pull_request] +# Declare default permissions as read only. +permissions: read-all +jobs: + test: + runs-on: ubuntu-24.04 + strategy: + matrix: + # for now, only test with Python 3.9+ (since we're testing in Ubuntu 24.04) + #python: [3.6, 3.7, 3.8, 3.9, '3.10', '3.11'] + python: ['3.9', '3.10', '3.11'] + fail-fast: false + steps: + - name: checkout + uses: actions/checkout@93ea575cb5d8a053eaa0ac8fa3b40d7e05a33cc8 # v3.1.0 + + - name: set up Python + uses: actions/setup-python@13ae5bb136fac2878aff31522b9efb785519f984 # v4.3.0 + with: + python-version: ${{matrix.python}} + + - name: Install required Python packages + pytest + flake8 + run: | + python -m pip install --upgrade pip + python -m pip install -r scripts/automated_ingestion/requirements.txt + python -m pip install pytest + python -m pip install --upgrade flake8 + + - name: Run test suite (without coverage) + run: | + ./scripts/automated_ingestion/pytest.sh scripts/automated_ingestion --verbose + + - name: Run test suite (with coverage) + run: | + python -m pip install pytest-cov + ./scripts/automated_ingestion/pytest.sh scripts/automated_ingestion -q --cov=scripts/automated_ingestion/eessi_logging.py + + - name: Run flake8 to verify PEP8-compliance of Python code + run: | + flake8 scripts/automated_ingestion --exclude=scripts/automated_ingestion/automated_ingestion.py,scripts/automated_ingestion/eessitarball.py \ No newline at end of file diff --git a/.gitignore b/.gitignore index 39af2bac..893c00e4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ build hosts +.coverage +**/__pycache__ diff --git a/scripts/automated_ingestion/.coveragerc b/scripts/automated_ingestion/.coveragerc new file mode 100644 index 00000000..2941a1ed --- /dev/null +++ b/scripts/automated_ingestion/.coveragerc @@ -0,0 +1,6 @@ +[run] +omit = + scripts/automated_ingestion/automated_ingestion.py + scripts/automated_ingestion/eessitarball.py + scripts/automated_ingestion/utils.py + scripts/automated_ingestion/unit_tests/*.py diff --git a/scripts/automated_ingestion/automated_ingestion.py b/scripts/automated_ingestion/automated_ingestion.py index 92dac552..7a7f9dc8 100755 --- a/scripts/automated_ingestion/automated_ingestion.py +++ b/scripts/automated_ingestion/automated_ingestion.py @@ -81,7 +81,7 @@ def parse_args(): return args -@pid.decorator.pidfile('automated_ingestion.pid') +@pidfile('shared_lock.pid') # noqa: F401 def main(): """Main function.""" args = parse_args() diff --git 
a/scripts/automated_ingestion/eessi_data_object.py b/scripts/automated_ingestion/eessi_data_object.py new file mode 100644 index 00000000..4989f24c --- /dev/null +++ b/scripts/automated_ingestion/eessi_data_object.py @@ -0,0 +1,344 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import configparser +import subprocess + +from eessi_logging import log_function_entry_exit, log_message, LoggingScope +from eessi_remote_storage_client import DownloadMode, EESSIRemoteStorageClient + + +@dataclass +class EESSIDataAndSignatureObject: + """Class representing an EESSI data file and its signature in remote storage and locally.""" + + # configuration + config: configparser.ConfigParser + + # remote paths + remote_file_path: str # path to data file in remote storage + remote_sig_path: str # path to signature file in remote storage + + # local paths + local_file_path: Path # path to local data file + local_sig_path: Path # path to local signature file + + # remote storage client + remote_client: EESSIRemoteStorageClient + + @log_function_entry_exit() + def __init__( + self, + config: configparser.ConfigParser, + remote_file_path: str, + remote_client: EESSIRemoteStorageClient, + ): + """ + Initialize an EESSI data and signature object handler. + + Args: + config: configuration object containing remote storage and local directory information + remote_file_path: path to data file in remote storage + remote_client: remote storage client implementing the EESSIRemoteStorageClient protocol + """ + self.config = config + self.remote_file_path = remote_file_path + sig_ext = config["signatures"]["signature_file_extension"] + self.remote_sig_path = remote_file_path + sig_ext + + # set up local paths + local_dir = Path(config["paths"]["download_dir"]) + # use the full remote path structure, removing any leading slashes + remote_path = remote_file_path.lstrip("/") + self.local_file_path = local_dir.joinpath(remote_path) + self.local_sig_path = local_dir.joinpath(remote_path + sig_ext) + self.remote_client = remote_client + + log_message(LoggingScope.DEBUG, "DEBUG", "Initialized EESSIDataAndSignatureObject for '%s'", remote_file_path) + log_message(LoggingScope.DEBUG, "DEBUG", "Local file path: '%s'", self.local_file_path) + log_message(LoggingScope.DEBUG, "DEBUG", "Local signature path: '%s'", self.local_sig_path) + + @log_function_entry_exit() + def _get_etag_file_path(self, local_path: Path) -> Path: + """Get the path to the .etag file for a given local file.""" + return local_path.with_suffix(".etag") + + @log_function_entry_exit() + def _get_local_etag(self, local_path: Path) -> Optional[str]: + """Get the ETag of a local file from its .etag file.""" + etag_path = self._get_etag_file_path(local_path) + if etag_path.exists(): + try: + with open(etag_path, "r") as f: + return f.read().strip() + except Exception as err: + log_message(LoggingScope.DEBUG, "WARNING", "Failed to read ETag file '%s': '%s'", etag_path, str(err)) + return None + return None + + @log_function_entry_exit() + def get_etags(self) -> tuple[Optional[str], Optional[str]]: + """ + Get the ETags of both the data file and its signature. + + Returns: + Tuple containing (data_file_etag, signature_file_etag) + """ + return ( + self._get_local_etag(self.local_file_path), + self._get_local_etag(self.local_sig_path) + ) + + @log_function_entry_exit() + def verify_signature(self) -> bool: + """ + Verify the signature of the data file using the corresponding signature file. 
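As an aside on the ETag bookkeeping above: the class caches each downloaded file's ETag in a sibling file derived with Path.with_suffix(".etag"). A minimal sketch of that convention (paths are hypothetical, not taken from the ingestion code); note that with_suffix() replaces only the final suffix, so a ".tar.gz" download gets a ".tar.etag" sidecar:

from pathlib import Path
from typing import Optional

def read_cached_etag(local_path: Path) -> Optional[str]:
    # mirrors _get_local_etag(): return the stored ETag, or None if no sidecar exists
    sidecar = local_path.with_suffix(".etag")
    if sidecar.exists():
        return sidecar.read_text().strip()
    return None

print(Path("/tmp/downloads/example-2023.06.tar.gz").with_suffix(".etag"))
# -> /tmp/downloads/example-2023.06.tar.etag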
+ + Returns: + bool: True if the signature is valid or if signatures are not required, False otherwise + """ + # check if signature file exists + if not self.local_sig_path.exists(): + log_message(LoggingScope.VERIFICATION, "WARNING", "Signature file '%s' is missing", + self.local_sig_path) + + # if signatures are required, return failure + if self.config["signatures"].getboolean("signatures_required", True): + log_message(LoggingScope.ERROR, "ERROR", "Signature file '%s' is missing and signatures are required", + self.local_sig_path) + return False + else: + log_message(LoggingScope.VERIFICATION, "INFO", + "Signature file '%s' is missing, but signatures are not required", + self.local_sig_path) + return True + + # if signatures are provided, we should always verify them, regardless of the signatures_required setting + verify_runenv = self.config["signatures"]["signature_verification_runenv"].split() + verify_script = self.config["signatures"]["signature_verification_script"] + allowed_signers_file = self.config["signatures"]["allowed_signers_file"] + + # check if verification tools exist + if not Path(verify_script).exists(): + log_message(LoggingScope.ERROR, "ERROR", + "Unable to verify signature: verification script '%s' does not exist", verify_script) + return False + + if not Path(allowed_signers_file).exists(): + log_message(LoggingScope.ERROR, "ERROR", + "Unable to verify signature: allowed signers file '%s' does not exist", allowed_signers_file) + return False + + # run the verification command with named parameters + cmd = verify_runenv + [ + verify_script, + "--verify", + "--allowed-signers-file", allowed_signers_file, + "--file", str(self.local_file_path), + "--signature-file", str(self.local_sig_path) + ] + log_message(LoggingScope.VERIFICATION, "INFO", "Running command: '%s'", " ".join(cmd)) + + try: + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode == 0: + log_message(LoggingScope.VERIFICATION, "INFO", + "Successfully verified signature for '%s'", self.local_file_path) + log_message(LoggingScope.VERIFICATION, "DEBUG", " stdout: '%s'", result.stdout) + log_message(LoggingScope.VERIFICATION, "DEBUG", " stderr: '%s'", result.stderr) + return True + else: + log_message(LoggingScope.ERROR, "ERROR", + "Signature verification failed for '%s'", self.local_file_path) + log_message(LoggingScope.ERROR, "ERROR", " stdout: '%s'", result.stdout) + log_message(LoggingScope.ERROR, "ERROR", " stderr: '%s'", result.stderr) + return False + except Exception as err: + log_message(LoggingScope.ERROR, "ERROR", + "Error during signature verification for '%s': '%s'", + self.local_file_path, str(err)) + return False + + @log_function_entry_exit() + def download(self, mode: DownloadMode = DownloadMode.CHECK_REMOTE) -> bool: + """ + Download data file and signature based on the specified mode. 
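For context, the verification step above relies on a handful of [signatures] options. A hedged sketch of such a configuration and of the command list that verify_signature() assembles from it (all paths and the tarball name are made up):

import configparser

config = configparser.ConfigParser()
config["signatures"] = {
    "signature_file_extension": ".sig",
    "signatures_required": "true",
    "signature_verification_runenv": "/usr/bin/env bash",
    "signature_verification_script": "/opt/eessi/verify_signature.sh",
    "allowed_signers_file": "/opt/eessi/allowed_signers",
}

# assembled the same way as in verify_signature()
cmd = config["signatures"]["signature_verification_runenv"].split() + [
    config["signatures"]["signature_verification_script"],
    "--verify",
    "--allowed-signers-file", config["signatures"]["allowed_signers_file"],
    "--file", "/tmp/downloads/example.tar.gz",
    "--signature-file", "/tmp/downloads/example.tar.gz.sig",
]
print(" ".join(cmd))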
+ + Args: + mode: Download mode to use + + Returns: + True if files were downloaded, False otherwise + """ + # if mode is FORCE, we always download regardless of local or remote state + if mode == DownloadMode.FORCE: + should_download = True + log_message(LoggingScope.DOWNLOAD, "INFO", "Forcing download of '%s'", self.remote_file_path) + # for CHECK_REMOTE mode, check if we can optimize + elif mode == DownloadMode.CHECK_REMOTE: + # optimization: check if local files exist first + local_files_exist = ( + self.local_file_path.exists() and + self.local_sig_path.exists() + ) + + # if files don't exist locally, we can skip ETag checks + if not local_files_exist: + log_message(LoggingScope.DOWNLOAD, "INFO", + "Local files missing, skipping ETag checks and downloading '%s'", + self.remote_file_path) + should_download = True + else: + # first check if we have local ETags + try: + local_file_etag = self._get_local_etag(self.local_file_path) + local_sig_etag = self._get_local_etag(self.local_sig_path) + + if local_file_etag: + log_message(LoggingScope.DOWNLOAD, "DEBUG", "Local file ETag: '%s'", local_file_etag) + else: + log_message(LoggingScope.DOWNLOAD, "DEBUG", "No local file ETag found") + if local_sig_etag: + log_message(LoggingScope.DOWNLOAD, "DEBUG", "Local signature ETag: '%s'", local_sig_etag) + else: + log_message(LoggingScope.DOWNLOAD, "DEBUG", "No local signature ETag found") + + # if we don't have local ETags, we need to download + if not local_file_etag or not local_sig_etag: + should_download = True + log_message(LoggingScope.DOWNLOAD, "INFO", "Missing local ETags, downloading '%s'", + self.remote_file_path) + else: + # get remote ETags and compare + remote_file_etag = self.remote_client.get_metadata(self.remote_file_path)["ETag"] + remote_sig_etag = self.remote_client.get_metadata(self.remote_sig_path)["ETag"] + log_message(LoggingScope.DOWNLOAD, "DEBUG", "Remote file ETag: '%s'", remote_file_etag) + log_message(LoggingScope.DOWNLOAD, "DEBUG", "Remote signature ETag: '%s'", remote_sig_etag) + + should_download = ( + remote_file_etag != local_file_etag or + remote_sig_etag != local_sig_etag + ) + if should_download: + if remote_file_etag != local_file_etag: + log_message(LoggingScope.DOWNLOAD, "INFO", "File ETag changed from '%s' to '%s'", + local_file_etag, remote_file_etag) + if remote_sig_etag != local_sig_etag: + log_message(LoggingScope.DOWNLOAD, "INFO", "Signature ETag changed from '%s' to '%s'", + local_sig_etag, remote_sig_etag) + log_message(LoggingScope.DOWNLOAD, "INFO", "Remote files have changed, downloading '%s'", + self.remote_file_path) + else: + log_message(LoggingScope.DOWNLOAD, "INFO", + "Remote files unchanged, skipping download of '%s'", + self.remote_file_path) + except Exception as etag_err: + # if we get any error with ETags, we'll just download the files + log_message(LoggingScope.DOWNLOAD, "DEBUG", "Error handling ETags, will download files: '%s'", + str(etag_err)) + should_download = True + else: # check_local + should_download = ( + not self.local_file_path.exists() or + not self.local_sig_path.exists() + ) + if should_download: + if not self.local_file_path.exists(): + log_message(LoggingScope.DOWNLOAD, "INFO", "Local file missing: '%s'", self.local_file_path) + if not self.local_sig_path.exists(): + log_message(LoggingScope.DOWNLOAD, "INFO", "Local signature missing: '%s'", self.local_sig_path) + log_message(LoggingScope.DOWNLOAD, "INFO", "Local files missing, downloading '%s'", + self.remote_file_path) + else: + log_message(LoggingScope.DOWNLOAD, "INFO", 
"Local files exist, skipping download of '%s'", + self.remote_file_path) + + if not should_download: + return False + + # ensure local directory exists + self.local_file_path.parent.mkdir(parents=True, exist_ok=True) + + # download files + try: + # download the main file first + self.remote_client.download(self.remote_file_path, str(self.local_file_path)) + + # get and log the ETag of the downloaded file + try: + file_etag = self._get_local_etag(self.local_file_path) + log_message(LoggingScope.DOWNLOAD, "DEBUG", "Downloaded '%s' with ETag: '%s'", + self.remote_file_path, file_etag) + except Exception as etag_err: + log_message(LoggingScope.DOWNLOAD, "DEBUG", "Error getting ETag for '%s': '%s'", + self.remote_file_path, str(etag_err)) + + # try to download the signature file + try: + self.remote_client.download(self.remote_sig_path, str(self.local_sig_path)) + try: + sig_etag = self._get_local_etag(self.local_sig_path) + log_message(LoggingScope.DOWNLOAD, "DEBUG", "Downloaded '%s' with ETag: '%s'", + self.remote_sig_path, sig_etag) + except Exception as etag_err: + log_message(LoggingScope.DOWNLOAD, "DEBUG", "Error getting ETag for '%s': '%s'", + self.remote_sig_path, str(etag_err)) + log_message(LoggingScope.DOWNLOAD, "INFO", "Successfully downloaded '%s' and its signature", + self.remote_file_path) + except Exception as sig_err: + # check if signatures are required + if self.config["signatures"].getboolean("signatures_required", True): + # if signatures are required, clean up everything since we can't proceed + if self.local_file_path.exists(): + self.local_file_path.unlink() + # clean up etag files regardless of whether their data files exist + file_etag_path = self._get_etag_file_path(self.local_file_path) + if file_etag_path.exists(): + file_etag_path.unlink() + sig_etag_path = self._get_etag_file_path(self.local_sig_path) + if sig_etag_path.exists(): + sig_etag_path.unlink() + log_message(LoggingScope.ERROR, "ERROR", "Failed to download required signature for '%s': '%s'", + self.remote_file_path, str(sig_err)) + raise + else: + # if signatures are optional, just clean up any partial signature files + if self.local_sig_path.exists(): + self.local_sig_path.unlink() + sig_etag_path = self._get_etag_file_path(self.local_sig_path) + if sig_etag_path.exists(): + sig_etag_path.unlink() + log_message(LoggingScope.DOWNLOAD, "WARNING", + "Failed to download optional signature for '%s': '%s'", + self.remote_file_path, str(sig_err)) + log_message(LoggingScope.DOWNLOAD, "INFO", + "Successfully downloaded '%s' (signature optional)", + self.remote_file_path) + + return True + except Exception as err: + # this catch block is only for errors in the main file download + # clean up partially downloaded files and their etags + if self.local_file_path.exists(): + self.local_file_path.unlink() + if self.local_sig_path.exists(): + self.local_sig_path.unlink() + # clean up etag files regardless of whether their data files exist + file_etag_path = self._get_etag_file_path(self.local_file_path) + if file_etag_path.exists(): + file_etag_path.unlink() + sig_etag_path = self._get_etag_file_path(self.local_sig_path) + if sig_etag_path.exists(): + sig_etag_path.unlink() + log_message(LoggingScope.ERROR, "ERROR", "Failed to download '%s': '%s'", self.remote_file_path, str(err)) + raise + + @log_function_entry_exit() + def get_url(self) -> str: + """Get the URL of the data file.""" + return f"https://{self.remote_client.bucket}.s3.amazonaws.com/{self.remote_file_path}" + + def __str__(self) -> str: + """Return a 
string representation of the EESSI data and signature object.""" + return f"EESSIDataAndSignatureObject({self.remote_file_path})" diff --git a/scripts/automated_ingestion/eessi_logging.py b/scripts/automated_ingestion/eessi_logging.py new file mode 100644 index 00000000..f94c8495 --- /dev/null +++ b/scripts/automated_ingestion/eessi_logging.py @@ -0,0 +1,262 @@ +import functools +import inspect +import logging +import os +import sys +import time + +from enum import IntFlag, auto +from typing import Callable, Union + + +LOG_LEVELS = { + 'DEBUG': logging.DEBUG, + 'INFO': logging.INFO, + 'WARNING': logging.WARNING, + 'ERROR': logging.ERROR, + 'CRITICAL': logging.CRITICAL +} + + +class LoggingScope(IntFlag): + """Enumeration of different logging scopes.""" + NONE = 0 + FUNC_ENTRY_EXIT = auto() # Function entry/exit logging + DOWNLOAD = auto() # Logging related to file downloads + VERIFICATION = auto() # Logging related to signature and checksum verification + STATE_OPS = auto() # Logging related to tarball state operations + GITHUB_OPS = auto() # Logging related to GitHub operations (PRs, issues, etc.) + GROUP_OPS = auto() # Logging related to tarball group operations + TASK_OPS = auto() # Logging related to task operations + TASK_OPS_DETAILS = auto() # Logging related to task operations (detailed) + ERROR = auto() # Error logging (separate from other scopes for easier filtering) + DEBUG = auto() # Debug-level logging (separate from other scopes for easier filtering) + ALL = (FUNC_ENTRY_EXIT | DOWNLOAD | VERIFICATION | STATE_OPS | + GITHUB_OPS | GROUP_OPS | TASK_OPS | TASK_OPS_DETAILS | ERROR | DEBUG) + + +# Global setting for logging scopes +ENABLED_LOGGING_SCOPES = LoggingScope.NONE + + +# Global variable to track call stack depth +_call_stack_depth = 0 + + +def error(msg, code=1): + """Print an error and exit.""" + log_message(LoggingScope.ERROR, 'ERROR', msg) + sys.exit(code) + + +def is_logging_scope_enabled(scope: LoggingScope) -> bool: + """Check if a specific logging scope is enabled.""" + return bool(ENABLED_LOGGING_SCOPES & scope) + + +def log_function_entry_exit(logger: logging.Logger = None) -> Callable: + """ + Decorator that logs function entry and exit with timing information. + Only logs if the FUNC_ENTRY_EXIT scope is enabled. + + Args: + logger: Optional logger instance. If not provided, uses the module's logger. 
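Because LoggingScope is an IntFlag, scope sets are plain bit masks; a quick illustration of the combinations that is_logging_scope_enabled() tests against ENABLED_LOGGING_SCOPES:

from eessi_logging import LoggingScope

enabled = LoggingScope.DOWNLOAD | LoggingScope.ERROR           # combine two scopes
print(bool(enabled & LoggingScope.DOWNLOAD))                    # True
print(bool(enabled & LoggingScope.VERIFICATION))                # False
print(bool((LoggingScope.ALL & ~LoggingScope.DEBUG) & LoggingScope.DEBUG))  # False: "ALL minus DEBUG"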
+ """ + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + global _call_stack_depth + + if not is_logging_scope_enabled(LoggingScope.FUNC_ENTRY_EXIT): + return func(*args, **kwargs) + + if logger is None: + log = logging.getLogger(func.__module__) + else: + log = logger + + # Get context information if available + context = "" + if len(args) > 0 and hasattr(args[0], 'object'): + # For EessiTarball methods, show the tarball name and state + tarball = args[0] + filename = os.path.basename(tarball.object) + + # Format filename to show important parts + if len(filename) > 30: + parts = filename.split('-') + if len(parts) >= 6: # Ensure we have all required parts + # Get version, component, last part of architecture, and epoch + version = parts[1] + component = parts[2] + arch_last = parts[-2].split('-')[-1] # Last part of architecture + epoch = parts[-1] # includes file extension + filename = f"{version}-{component}-{arch_last}-{epoch}" + else: + # Fallback to simple truncation if format doesn't match + filename = f"{filename[:15]}...{filename[-12:]}" + + context = f" [{filename}" + if hasattr(tarball, 'state'): + context += f" in {tarball.state}" + context += "]" + + # Create indentation based on call stack depth + indent = " " * _call_stack_depth + + # Get file name and line number where the function is defined + file_name = os.path.basename(inspect.getsourcefile(func)) + source_lines, start_line = inspect.getsourcelines(func) + # Find the line with the actual function definition + def_line = next(i for i, line in enumerate(source_lines) if line.strip().startswith('def ')) + def_line_no = start_line + def_line + # Find the last non-empty line of the function + last_line = next(i for i, line in enumerate(reversed(source_lines)) if line.strip()) + last_line_no = start_line + len(source_lines) - 1 - last_line + + start_time = time.time() + log.info(f"{indent}[FUNC_ENTRY_EXIT] Entering {func.__name__} at {file_name}:{def_line_no}{context}") + _call_stack_depth += 1 + try: + result = func(*args, **kwargs) + _call_stack_depth -= 1 + end_time = time.time() + # For normal returns, show the last line of the function + log.info(f"{indent}[FUNC_ENTRY_EXIT] Leaving {func.__name__} at {file_name}:{last_line_no}" + f"{context} (took {end_time - start_time:.2f}s)") + return result + except Exception as err: + _call_stack_depth -= 1 + end_time = time.time() + # For exceptions, try to get the line number from the exception + try: + exc_line_no = err.__traceback__.tb_lineno + except AttributeError: + exc_line_no = last_line_no + log.info(f"{indent}[FUNC_ENTRY_EXIT] Leaving {func.__name__} at {file_name}:{exc_line_no}" + f"{context} with exception (took {end_time - start_time:.2f}s)") + raise err + return wrapper + return decorator + + +def log_message(scope, level, msg, *args, logger=None, **kwargs): + """ + Log a message if either: + 1. The specified scope is enabled, OR + 2. The current log level is equal to or higher than the specified level + + Args: + scope: LoggingScope value indicating which scope this logging belongs to + level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) + msg: Message to log + logger: Optional logger instance. If not provided, uses the root logger. 
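A small usage sketch of the decorator (the decorated helper is hypothetical and the log lines shown are approximate):

import logging
from eessi_logging import log_function_entry_exit, set_logging_scopes

logging.basicConfig(level=logging.INFO)
set_logging_scopes("+FUNC_ENTRY_EXIT")

@log_function_entry_exit()
def stage_example(name):
    # hypothetical helper, not part of the ingestion scripts
    return f"staged {name}"

stage_example("example.tar.gz")
# produces INFO records roughly like:
#   [FUNC_ENTRY_EXIT] Entering stage_example at <file>:<line>
#   [FUNC_ENTRY_EXIT] Leaving stage_example at <file>:<line> (took 0.00s)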
+ *args, **kwargs: Additional arguments to pass to the logging function + """ + log = logger or logging.getLogger() + log_level = getattr(logging, level.upper()) + + # Check if either condition is met + if not (is_logging_scope_enabled(scope) or log_level >= log.getEffectiveLevel()): + return + + # Create indentation based on call stack depth + indent = " " * _call_stack_depth + # Add scope to the message + scoped_msg = f"[{scope.name}] {msg}" + indented_msg = f"{indent}{scoped_msg}" + + # If scope is enabled, use the temporary handler + if is_logging_scope_enabled(scope): + # Save original handlers + original_handlers = list(log.handlers) + + # Create a temporary handler that accepts all levels + temp_handler = logging.StreamHandler(sys.stdout) + temp_handler.setLevel(logging.DEBUG) + temp_handler.setFormatter(logging.Formatter('%(levelname)-8s: %(message)s')) + + try: + # Remove existing handlers temporarily + for handler in original_handlers: + log.removeHandler(handler) + + # Add temporary handler + log.addHandler(temp_handler) + + # Log the message + log_func = getattr(log, level.lower()) + log_func(indented_msg, *args, **kwargs) + finally: + log.removeHandler(temp_handler) + # Restore original handlers + for handler in original_handlers: + if handler not in log.handlers: + log.addHandler(handler) + # Only use normal logging if scope is not enabled AND level is high enough + elif not is_logging_scope_enabled(scope) and log_level >= log.getEffectiveLevel(): + # Use normal logging with level check + log_func = getattr(log, level.lower()) + log_func(indented_msg, *args, **kwargs) + + +def set_logging_scopes(scopes: Union[LoggingScope, str, list[str]]) -> None: + """ + Set the enabled logging scopes. + + Args: + scopes: Can be + - A LoggingScope value + - A string with comma-separated values using +/- syntax: + - "+SCOPE" to enable a scope + - "-SCOPE" to disable a scope + - "ALL" or "+ALL" to enable all scopes + - "-ALL" to disable all scopes + Examples: + "+FUNC_ENTRY_EXIT" # Enable only function entry/exit + "+FUNC_ENTRY_EXIT,-EXAMPLE_SCOPE" # Enable function entry/exit but disable example + "+ALL,-FUNC_ENTRY_EXIT" # Enable all scopes except function entry/exit + """ + global ENABLED_LOGGING_SCOPES + + if isinstance(scopes, LoggingScope): + ENABLED_LOGGING_SCOPES = scopes + return + + if isinstance(scopes, str): + # Start with no scopes enabled + ENABLED_LOGGING_SCOPES = LoggingScope.NONE + + # Split into individual scope specifications + scope_specs = [s.strip() for s in scopes.split(",")] + + for spec in scope_specs: + if not spec: + continue + + # Check for ALL special case + if spec.upper() in ["ALL", "+ALL"]: + ENABLED_LOGGING_SCOPES = LoggingScope.ALL + continue + elif spec.upper() == "-ALL": + ENABLED_LOGGING_SCOPES = LoggingScope.NONE + continue + + # Parse scope name and operation + operation = spec[0] + scope_name = spec[1:].strip().upper() + + try: + scope_enum = LoggingScope[scope_name] + if operation == '+': + ENABLED_LOGGING_SCOPES |= scope_enum + elif operation == '-': + ENABLED_LOGGING_SCOPES &= ~scope_enum + else: + logging.warning(f"Invalid operation '{operation}' in scope specification: {spec}") + except KeyError: + logging.warning(f"Unknown logging scope: {scope_name}") + + elif isinstance(scopes, list): + # Convert list to comma-separated string and process + set_logging_scopes(",".join(scopes)) diff --git a/scripts/automated_ingestion/eessi_remote_storage_client.py b/scripts/automated_ingestion/eessi_remote_storage_client.py new file mode 100644 index 
00000000..9f83d721 --- /dev/null +++ b/scripts/automated_ingestion/eessi_remote_storage_client.py @@ -0,0 +1,34 @@ +from enum import Enum +from typing import Protocol, runtime_checkable + + +class DownloadMode(Enum): + """Enum defining different modes for downloading files.""" + FORCE = 'force' # Always download and overwrite + CHECK_REMOTE = 'check-remote' # Download if remote files have changed + CHECK_LOCAL = 'check-local' # Download if files don't exist locally (default) + + +@runtime_checkable +class EESSIRemoteStorageClient(Protocol): + """Protocol defining the interface for remote storage clients.""" + + def get_metadata(self, remote_path: str) -> dict: + """Get metadata about a remote object. + + Args: + remote_path: Path to the object in remote storage + + Returns: + Dictionary containing object metadata, including 'ETag' key + """ + ... + + def download(self, remote_path: str, local_path: str) -> None: + """Download a remote file to a local location. + + Args: + remote_path: Path to the object in remote storage + local_path: Local path where to save the file + """ + ... diff --git a/scripts/automated_ingestion/eessi_s3_bucket.py b/scripts/automated_ingestion/eessi_s3_bucket.py new file mode 100644 index 00000000..bc5a8822 --- /dev/null +++ b/scripts/automated_ingestion/eessi_s3_bucket.py @@ -0,0 +1,191 @@ +import os +from pathlib import Path +from typing import Dict, Optional + +import boto3 +from botocore.exceptions import ClientError +from eessi_logging import log_function_entry_exit, log_message, LoggingScope +from eessi_remote_storage_client import EESSIRemoteStorageClient + + +class EESSIS3Bucket(EESSIRemoteStorageClient): + """EESSI-specific S3 bucket implementation of the EESSIRemoteStorageClient protocol.""" + + @log_function_entry_exit() + def __init__(self, config, bucket_name: str): + """ + Initialize the EESSI S3 bucket. 
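Since the protocol above is runtime_checkable and only requires get_metadata() and download(), any structurally matching class can stand in for a real bucket, e.g. in tests. A minimal, hypothetical local-directory client (the MD5 digest is only a stand-in for an S3 ETag):

import hashlib
import shutil
from pathlib import Path

from eessi_remote_storage_client import EESSIRemoteStorageClient

class LocalDirStorageClient:
    """Serve files from a local directory instead of a remote bucket (test double)."""

    def __init__(self, root: str):
        self.root = Path(root)

    def get_metadata(self, remote_path: str) -> dict:
        data = (self.root / remote_path.lstrip("/")).read_bytes()
        return {"ETag": hashlib.md5(data).hexdigest()}

    def download(self, remote_path: str, local_path: str) -> None:
        shutil.copyfile(self.root / remote_path.lstrip("/"), local_path)

# structural isinstance() check provided by @runtime_checkable
assert isinstance(LocalDirStorageClient("/tmp/fake-bucket"), EESSIRemoteStorageClient)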
+ + Args: + config: Configuration object containing: + - aws.access_key_id: AWS access key ID (optional, can use AWS_ACCESS_KEY_ID env var) + - aws.secret_access_key: AWS secret access key (optional, can use AWS_SECRET_ACCESS_KEY env var) + - aws.endpoint_url: Custom endpoint URL for S3-compatible backends (optional) + - aws.verify: SSL verification setting (optional) + - True: Verify SSL certificates (default) + - False: Skip SSL certificate verification + - str: Path to CA bundle file + bucket_name: Name of the S3 bucket to use + """ + self.bucket = bucket_name + + # get AWS credentials from environment or config + aws_access_key_id = os.getenv("AWS_ACCESS_KEY_ID") or config.get("secrets", "aws_access_key_id") + aws_secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY") or config.get("secrets", "aws_secret_access_key") + + # configure boto3 client + client_config = {} + + # add endpoint URL if specified in config + if config.has_option("aws", "endpoint_url"): + client_config["endpoint_url"] = config["aws"]["endpoint_url"] + log_message(LoggingScope.DEBUG, "DEBUG", "Using custom endpoint URL: '%s'", client_config["endpoint_url"]) + + # add SSL verification if specified in config + if config.has_option("aws", "verify"): + verify = config["aws"]["verify"] + if verify.lower() == "false": + client_config["verify"] = False + log_message(LoggingScope.DEBUG, "WARNING", "SSL verification disabled") + elif verify.lower() == "true": + client_config["verify"] = True + else: + client_config["verify"] = verify # assume it's a path to CA bundle + log_message(LoggingScope.DEBUG, "DEBUG", "Using custom CA bundle: '%s'", verify) + + self.client = boto3.client( + "s3", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + **client_config + ) + log_message(LoggingScope.DEBUG, "INFO", "Initialized S3 client for bucket: '%s'", self.bucket) + + @log_function_entry_exit() + def download(self, remote_path: str, local_path: str) -> None: + """ + Download an S3 object to a local location and store its ETag. + + Args: + remote_path: Path to the object in S3 + local_path: Local path where to save the file + """ + try: + log_message(LoggingScope.DOWNLOAD, "INFO", "Downloading '%s' to '%s'", remote_path, local_path) + self.client.download_file(Bucket=self.bucket, Key=remote_path, Filename=local_path) + log_message(LoggingScope.DOWNLOAD, "INFO", "Successfully downloaded '%s' to '%s'", remote_path, local_path) + except ClientError as err: + log_message(LoggingScope.ERROR, "ERROR", "Failed to download '%s': '%s'", remote_path, str(err)) + raise + + # get metadata first to obtain the ETag + metadata = self.get_metadata(remote_path) + etag = metadata["ETag"] + + # store the ETag + self._write_etag(local_path, etag) + + @log_function_entry_exit() + def download_file(self, key: str, filename: str) -> None: + """ + Download a file from S3 to a local file. + + Args: + key: The S3 key of the file to download + filename: The local path where the file should be saved + """ + self.client.download_file(self.bucket, key, filename) + + @log_function_entry_exit() + def get_bucket_url(self) -> str: + """ + Get the HTTPS URL for a bucket from an initialized boto3 client. + Works with both AWS S3 and MinIO/S3-compatible services. + """ + try: + # check if this is a custom endpoint (MinIO) or AWS S3 + endpoint_url = self.client.meta.endpoint_url + + if endpoint_url: + # custom endpoint (MinIO, DigitalOcean Spaces, etc.) 
+ # most S3-compatible services use path-style URLs + bucket_url = f"{endpoint_url}/{self.bucket}" + else: + # AWS S3 (no custom endpoint specified) + region = self.client.meta.region_name or 'us-east-1' + + # AWS S3 virtual-hosted-style URLs + if region == "us-east-1": + bucket_url = f"https://{self.bucket}.s3.amazonaws.com" + else: + bucket_url = f"https://{self.bucket}.s3.{region}.amazonaws.com" + + return bucket_url + + except Exception as err: + log_message(LoggingScope.ERROR, "ERROR", "Error getting bucket URL: '%s'", str(err)) + return None + + @log_function_entry_exit() + def get_metadata(self, remote_path: str) -> Dict: + """ + Get metadata about an S3 object. + + Args: + remote_path: Path to the object in S3 + + Returns: + Dictionary containing object metadata, including 'ETag' key + """ + try: + log_message(LoggingScope.DEBUG, "DEBUG", "Getting metadata for S3 object: '%s'", remote_path) + response = self.client.head_object(Bucket=self.bucket, Key=remote_path) + log_message(LoggingScope.DEBUG, "DEBUG", "Retrieved metadata for '%s': '%s'", remote_path, response) + return response + except ClientError as err: + log_message(LoggingScope.ERROR, "ERROR", "Failed to get metadata for '%s': '%s'", remote_path, str(err)) + raise + + @log_function_entry_exit() + def _get_etag_file_path(self, local_path: str) -> Path: + """Get the path to the .etag file for a given local file.""" + return Path(local_path).with_suffix(".etag") + + @log_function_entry_exit() + def list_objects_v2(self, **kwargs): + """ + List objects in the bucket using the underlying boto3 client. + + Args: + **kwargs: Additional arguments to pass to boto3.client.list_objects_v2 + + Returns: + Response from boto3.client.list_objects_v2 + """ + return self.client.list_objects_v2(Bucket=self.bucket, **kwargs) + + @log_function_entry_exit() + def _read_etag(self, local_path: str) -> Optional[str]: + """Read the ETag from the .etag file if it exists.""" + etag_path = self._get_etag_file_path(local_path) + if etag_path.exists(): + try: + with open(etag_path, "r") as f: + return f.read().strip() + except Exception as e: + log_message(LoggingScope.DEBUG, "WARNING", "Failed to read ETag file '%s': '%s'", etag_path, str(e)) + return None + return None + + @log_function_entry_exit() + def _write_etag(self, local_path: str, etag: str) -> None: + """Write the ETag to the .etag file.""" + etag_path = self._get_etag_file_path(local_path) + try: + with open(etag_path, "w") as f: + f.write(etag) + log_message(LoggingScope.DEBUG, "DEBUG", "Wrote ETag to '%s'", etag_path) + except Exception as err: + log_message(LoggingScope.ERROR, "ERROR", "Failed to write ETag file '%s': '%s'", etag_path, str(err)) + # if we can't write the etag file, it's not critical + # the file will just be downloaded again next time diff --git a/scripts/automated_ingestion/eessi_task.py b/scripts/automated_ingestion/eessi_task.py new file mode 100644 index 00000000..86fbd8df --- /dev/null +++ b/scripts/automated_ingestion/eessi_task.py @@ -0,0 +1,1201 @@ +from enum import Enum, auto +from functools import total_ordering +from typing import Dict, List, Optional, Any + +import base64 +import os +import subprocess +import traceback + +from eessi_data_object import EESSIDataAndSignatureObject +from eessi_logging import log_function_entry_exit, log_message, LoggingScope +from eessi_task_action import EESSITaskAction +from eessi_task_description import EESSITaskDescription +from eessi_task_payload import EESSITaskPayload +from utils import send_slack_message + +from github 
import Github, GithubException, InputGitTreeElement, UnknownObjectException +from github.Branch import Branch +from github.PullRequest import PullRequest + + +@total_ordering +class EESSITaskState(Enum): + UNDETERMINED = auto() # The task state was not determined yet + NEW_TASK = auto() # The task has been created but not yet processed + PAYLOAD_STAGED = auto() # The task's payload has been staged to the Stratum-0 + PULL_REQUEST = auto() # A PR for the task has been created or updated in some staging repository + APPROVED = auto() # The PR for the task has been approved + REJECTED = auto() # The PR for the task has been rejected + INGESTED = auto() # The task's payload has been applied to the target CernVM-FS repository + DONE = auto() # The task has been completed + + @classmethod + def from_string( + cls, name: str, default: Optional["EESSITaskState"] = None, case_sensitive: bool = False + ) -> "EESSITaskState": + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "from_string: '%s'", name) + if case_sensitive: + to_return = cls.__members__.get(name, default) + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "from_string will return: '%s'", to_return) + return to_return + + try: + to_return = cls[name.upper()] + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "from_string will return: '%s'", to_return) + return to_return + except KeyError: + return default + + def __lt__(self, other): + if self.__class__ is other.__class__: + return self.value < other.value + return NotImplemented + + def __str__(self): + return self.name.upper() + + +class EESSITask: + description: EESSITaskDescription + payload: EESSITaskPayload + action: EESSITaskAction + git_repo: Github + config: Dict + + @log_function_entry_exit() + def __init__(self, description: EESSITaskDescription, config: Dict, cvmfs_repo: str, git_repo: Github): + self.description = description + self.config = config + self.cvmfs_repo = cvmfs_repo + self.git_repo = git_repo + self.action = self._determine_task_action() + + # define valid state transitions for all actions + # NOTE, for EESSITaskState.PULL_REQUEST, EESSITaskState.APPROVED must be the first element or + # _next_state() will not work correctly + self.valid_transitions = { + EESSITaskState.UNDETERMINED: [ + EESSITaskState.NEW_TASK, + EESSITaskState.PAYLOAD_STAGED, + EESSITaskState.PULL_REQUEST, + EESSITaskState.APPROVED, + EESSITaskState.REJECTED, + EESSITaskState.INGESTED, + EESSITaskState.DONE, + ], + EESSITaskState.NEW_TASK: [EESSITaskState.PAYLOAD_STAGED], + EESSITaskState.PAYLOAD_STAGED: [EESSITaskState.PULL_REQUEST], + EESSITaskState.PULL_REQUEST: [EESSITaskState.APPROVED, EESSITaskState.REJECTED], + EESSITaskState.APPROVED: [EESSITaskState.INGESTED], + EESSITaskState.REJECTED: [], # terminal state + EESSITaskState.INGESTED: [], # terminal state + EESSITaskState.DONE: [] # virtual terminal state, not used to write on GitHub + } + + self.payload = None + state = self.determine_state() + if state >= EESSITaskState.PAYLOAD_STAGED: + log_message(LoggingScope.TASK_OPS, "INFO", "initializing payload object in constructor for EESSITask") + self._init_payload_object() + + @log_function_entry_exit() + def _determine_task_action(self) -> EESSITaskAction: + """ + Determine the action type based on task description metadata. 
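A brief illustration of the state helpers defined above (assumes the module's dependencies, e.g. PyGithub, are installed so eessi_task can be imported):

from eessi_task import EESSITaskState

# auto() assigns increasing values, so @total_ordering yields the natural progression
assert EESSITaskState.NEW_TASK < EESSITaskState.PAYLOAD_STAGED < EESSITaskState.INGESTED

# from_string() is case-insensitive by default and falls back to the given default
assert EESSITaskState.from_string("payload_staged") == EESSITaskState.PAYLOAD_STAGED
assert EESSITaskState.from_string("no-such-state", EESSITaskState.UNDETERMINED) == EESSITaskState.UNDETERMINED

# _next_state() simply takes the first entry of valid_transitions, e.g.
# valid_transitions[PULL_REQUEST] == [APPROVED, REJECTED]  ->  next state is APPROVED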
+        """
+        if "task" in self.description.metadata and "action" in self.description.metadata["task"]:
+            action_str = self.description.metadata["task"]["action"].lower()
+            if action_str == "nop":
+                return EESSITaskAction.NOP
+            elif action_str == "delete":
+                return EESSITaskAction.DELETE
+            elif action_str == "add":
+                return EESSITaskAction.ADD
+            elif action_str == "update":
+                return EESSITaskAction.UPDATE
+        # temporarily return EESSITaskAction.ADD as default because the metadata
+        # file does not have an action defined yet
+        return EESSITaskAction.ADD
+
+    @log_function_entry_exit()
+    def _state_file_with_prefix_exists_in_repo_branch(self, file_path_prefix: str, branch_name: str = None) -> bool:
+        """
+        Check if a file exists in a repository branch.
+
+        Args:
+            file_path_prefix: the prefix of the file path
+            branch_name: the branch to check
+
+        Returns:
+            True if a file with the prefix exists in the branch, False otherwise
+        """
+        branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+        # branch = self._get_branch_from_name(branch_name)
+        try:
+            # get all files in directory part of file_path_prefix
+            directory_part = os.path.dirname(file_path_prefix)
+            files = self.git_repo.get_contents(directory_part, ref=branch_name)
+            log_msg = "Found files '%s' in directory '%s' in branch '%s'"
+            log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", log_msg, files, directory_part, branch_name)
+            # check if any of the files has file_path_prefix as prefix
+            for file in files:
+                if file.path.startswith(file_path_prefix):
+                    log_msg = "Found file '%s' in directory '%s' in branch '%s'"
+                    log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", log_msg, file.path, directory_part, branch_name)
+                    return True
+            log_msg = "No file with prefix '%s' found in directory '%s' in branch '%s'"
+            log_message(LoggingScope.TASK_OPS, "INFO", log_msg, file_path_prefix, directory_part, branch_name)
+            return False
+        except UnknownObjectException:
+            # file_path does not exist in branch
+            log_msg = "Directory '%s' or file with prefix '%s' does not exist in branch '%s'"
+            log_message(LoggingScope.TASK_OPS, "INFO", log_msg, directory_part, file_path_prefix, branch_name)
+            return False
+        except GithubException as err:
+            if err.status == 404:
+                # file_path does not exist in branch
+                log_msg = "Directory '%s' or file with prefix '%s' does not exist in branch '%s'"
+                log_message(LoggingScope.TASK_OPS, "INFO", log_msg, directory_part, file_path_prefix, branch_name)
+                return False
+            else:
+                # if there was some other (e.g. connection) issue, log message and return False
+                log_msg = "Unable to determine the state of '%s', the GitHub API returned status '%s'!"
+                log_message(LoggingScope.ERROR, "WARNING", log_msg, file_path_prefix, err.status)
+                return False
+        return False
+
+    @log_function_entry_exit()
+    def _determine_sequence_numbers_including_task_file(self, repo: str, pr: str) -> Dict[int, bool]:
+        """
+        Determines in which sequence numbers the metadata/task file is included and in which it is not.
+        NOTE, we only need to check the default branch of the repository, because for a new task a file
+        is added to the default branch and for the subsequent processing of the task we use a different branch.
+        Thus, until the PR is closed, the task file stays in the default branch.
+
+        Args:
+            repo: the repository name
+            pr: the pull request number
+
+        Returns:
+            A dictionary with the sequence numbers as keys and a boolean value indicating if the metadata/task file is
+            included in that sequence number.
+ + Idea: + - The deployment for a single source PR could be split into multiple staging PRs each is assigned a unique + sequence number. + - For a given source PR (identified by the repo name and the PR number), a staging PR using a branch named + `REPO/PR_NUM/SEQ_NUM` is created. + - In the staging repo we create a corresponding directory `REPO/PR_NUM/SEQ_NUM`. + - If a metadata/task file is handled by the staging PR with sequence number, it is included in that directory. + - We iterate over all directories under `REPO/PR_NUM`: + - If the metadata/task file is available in the directory, we add the sequence number to the list. + + Note: this is a placeholder for now, as we do not know yet if we need to use a sequence number. + """ + sequence_numbers = {} + repo_pr_dir = f"{repo}/{pr}" + # iterate over all directories under repo_pr_dir + try: + directories = self._list_directory_contents(repo_pr_dir) + for dir in directories: + # check if the directory is a number + if dir.name.isdigit(): + # determine if a state file with prefix exists in the sequence number directory + # we need to use the basename of the remote file path + remote_file_path_basename = os.path.basename(self.description.task_object.remote_file_path) + state_file_name_prefix = f"{repo_pr_dir}/{dir.name}/{remote_file_path_basename}" + if self._state_file_with_prefix_exists_in_repo_branch(state_file_name_prefix): + sequence_numbers[int(dir.name)] = True + else: + sequence_numbers[int(dir.name)] = False + else: + # directory is not a number, so we skip it + continue + except FileNotFoundError: + # repo_pr_dir does not exist, so we return an empty dictionary + return {} + except GithubException as err: + if err.status != 404: # 404 is catched by FileNotFoundError + # some other error than the directory not existing + return {} + return sequence_numbers + + @log_function_entry_exit() + def _list_directory_contents(self, directory_path: str, branch_name: str = None) -> List[Any]: + """ + List the contents of a directory in a branch. + """ + try: + # Get contents of the directory + branch_name = self.git_repo.default_branch if branch_name is None else branch_name + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "listing contents of '%s' in branch '%s'", + directory_path, branch_name) + contents = self.git_repo.get_contents(directory_path, ref=branch_name) + + # If contents is a list, it means we successfully got directory contents + if isinstance(contents, list): + return contents + else: + # If it's not a list, it means the path is not a directory + raise ValueError(f"'{directory_path}' is not a directory") + except GithubException as err: + if err.status == 404: + raise FileNotFoundError(f"Directory not found: '{directory_path}'") + raise err + + @log_function_entry_exit() + def _next_state(self, state: EESSITaskState = None) -> EESSITaskState: + """ + Determine the next state based on the current state using the valid_transitions dictionary. + + NOTE, it assumes that function is only called for non-terminal states and that the next state is the first + element of the list returned by the valid_transitions dictionary. + """ + the_state = state if state is not None else self.determine_state() + return self.valid_transitions[the_state][0] + + @log_function_entry_exit() + def _path_exists_in_branch(self, path: str, branch_name: str = None) -> bool: + """ + Check if a path exists in a branch. 
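To make the sequence-number bookkeeping above concrete, a hypothetical example (repository name, PR number, and layout are made up):

# Staging-repository layout for source PR 1234 of EESSI/software-layer:
#   EESSI/software-layer/1234/0/<task-file>/...   -> task file handled by staging PR with sequence number 0
#   EESSI/software-layer/1234/1/                  -> directory exists, but the task file is not included there
#
# _determine_sequence_numbers_including_task_file("EESSI/software-layer", "1234") would then return:
expected = {0: True, 1: False}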
+ """ + branch_name = self.git_repo.default_branch if branch_name is None else branch_name + try: + self.git_repo.get_contents(path, ref=branch_name) + return True + except GithubException as err: + if err.status == 404: + return False + else: + raise err + + @log_function_entry_exit() + def _read_dict_from_string(self, content: str) -> dict: + """ + Read the dictionary from the string. + """ + config_dict = {} + for line in content.strip().split("\n"): + if "=" in line and not line.strip().startswith("#"): # Skip comments + key, value = line.split("=", 1) # Split only on first '=' + config_dict[key.strip()] = value.strip() + return config_dict + + @log_function_entry_exit() + def _read_pull_request_dir_from_file(self, task_pointer_file: str = None, branch_name: str = None) -> str: + """ + Read the pull request directory from the file in the given branch. + """ + # set default values for task pointer file and branch name + if task_pointer_file is None: + task_pointer_file = self.description.task_object.remote_file_path + if branch_name is None: + branch_name = self.git_repo.default_branch + log_message(LoggingScope.TASK_OPS, "INFO", "reading pull request directory from file '%s' in branch '%s'", + task_pointer_file, branch_name) + + # read the pull request directory from the file in the given branch + content = self.git_repo.get_contents(task_pointer_file, ref=branch_name) + + # Decode the content from base64 + content_str = content.decoded_content.decode("utf-8") + + # Parse into dictionary + config_dict = self._read_dict_from_string(content_str) + + target_dir = config_dict.get("target_dir", None) + return config_dict.get("pull_request_dir", target_dir) + + @log_function_entry_exit() + def _determine_pull_request_dir(self, task_pointer_file: str = None, branch_name: str = None) -> str: + """Determine the pull request directory via the task pointer file""" + return self._read_pull_request_dir_from_file(task_pointer_file=task_pointer_file, branch_name=branch_name) + + @log_function_entry_exit() + def _get_branch_from_name(self, branch_name: str = None) -> Optional[Branch]: + """ + Get a branch object from its name. + """ + branch_name = self.git_repo.default_branch if branch_name is None else branch_name + + try: + branch = self.git_repo.get_branch(branch_name) + log_message(LoggingScope.TASK_OPS, "INFO", "branch '%s' exists: '%s'", branch_name, branch) + return branch + except Exception as err: + log_message(LoggingScope.TASK_OPS, "ERROR", "error checking if branch '%s' exists: '%s'", + branch_name, err) + return None + + @log_function_entry_exit() + def _read_task_state_from_file(self, path: str, branch_name: str = None) -> EESSITaskState: + """ + Read the task state from the file in the given branch. + """ + branch_name = self.git_repo.default_branch if branch_name is None else branch_name + content = self.git_repo.get_contents(path, ref=branch_name) + + # Decode the content from base64 + content_str = content.decoded_content.decode("utf-8").strip() + log_message(LoggingScope.TASK_OPS, "INFO", "content in TaskState file: '%s'", content_str) + + task_state = EESSITaskState.from_string(content_str) + log_message(LoggingScope.TASK_OPS, "INFO", "task state: '%s'", task_state) + + return task_state + + @log_function_entry_exit() + def determine_state(self, branch: str = None) -> EESSITaskState: + """ + Determine the state of the task based on the state of the staging repository. 
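The task pointer file read here is a plain "key = value" text file; a hedged example of its content (paths are hypothetical, the format matches what _handle_add_undetermined() writes further below) and of the parsing done by _read_dict_from_string():

pointer_content = """\
remote_file_path = EESSI/2023.06/example-task.meta.txt
pull_request_dir = EESSI/software-layer/1234/0/example-task.meta.txt
"""

config_dict = {}
for line in pointer_content.strip().split("\n"):
    if "=" in line and not line.strip().startswith("#"):   # same rule as _read_dict_from_string()
        key, value = line.split("=", 1)
        config_dict[key.strip()] = value.strip()

assert config_dict["pull_request_dir"] == "EESSI/software-layer/1234/0/example-task.meta.txt"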
+ """ + # check if path representing the task file exists in the default branch or the "feature" branch + task_pointer_file = self.description.task_object.remote_file_path + branch_to_use = self.git_repo.default_branch if branch is None else branch + + if self._path_exists_in_branch(task_pointer_file, branch_name=branch_to_use): + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "path '%s' exists in branch '%s'", + task_pointer_file, branch_to_use) + + # get state from task file in branch to use + # - read the EESSITaskState file in pull request directory + pull_request_dir = self._determine_pull_request_dir(branch_name=branch_to_use) + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "pull request directory: '%s'", pull_request_dir) + task_state_file_path = f"{pull_request_dir}/TaskState" + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "task state file path: '%s'", task_state_file_path) + task_state = self._read_task_state_from_file(task_state_file_path, branch_to_use) + + log_message(LoggingScope.TASK_OPS, "INFO", "task state in branch '%s': %s", + branch_to_use, task_state) + return task_state + else: + log_message(LoggingScope.TASK_OPS, "INFO", "path '%s' does not exist in branch '%s'", + task_pointer_file, branch_to_use) + return EESSITaskState.UNDETERMINED + + @log_function_entry_exit() + def handle(self): + """ + Dynamically find and execute the appropriate handler based on action and state. + """ + state_before_handle = self.determine_state() + + # Construct handler method name + handler_name = f"_handle_{self.action}_{str(state_before_handle).lower()}" + + # Check if the handler exists + handler = getattr(self, handler_name, None) + + if handler and callable(handler): + # Execute the handler if it exists + return handler() + else: + # Default behavior for missing handlers + log_message(LoggingScope.TASK_OPS, "ERROR", + "No handler for action '%s' and state '%s' implemented; nothing to be done", + self.action, state_before_handle) + return state_before_handle + + # Implement handlers for ADD action + @log_function_entry_exit() + def _safe_create_file(self, path: str, message: str, content: str, branch_name: str = None): + """Create a file in the given branch.""" + try: + branch_name = self.git_repo.default_branch if branch_name is None else branch_name + existing_file = self.git_repo.get_contents(path, ref=branch_name) + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "File '%s' already exists", path) + return existing_file + except GithubException as err: + if err.status == 404: # File doesn't exist + # Safe to create + return self.git_repo.create_file(path, message, content, branch=branch_name) + else: + raise err # Some other error + + @log_function_entry_exit() + def _create_multi_file_commit(self, files_data, commit_message, branch_name: str = None): + """ + Create a commit with multiple file changes + + files_data: dict with structure: + { + "path/to/file1.txt": { + "content": "file content", + "mode": "100644" # optional, defaults to 100644 + }, + "path/to/file2.py": { + "content": "print('hello')", + "mode": "100644" + } + } + """ + branch_name = self.git_repo.default_branch if branch_name is None else branch_name + ref = self.git_repo.get_git_ref(f"heads/{branch_name}") + current_commit = self.git_repo.get_git_commit(ref.object.sha) + base_tree = current_commit.tree + + # Create tree elements + tree_elements = [] + for file_path, file_info in files_data.items(): + content = file_info["content"] + if isinstance(content, str): + content = content.encode("utf-8") 
+ + blob = self.git_repo.create_git_blob( + base64.b64encode(content).decode("utf-8"), + "base64" + ) + tree_elements.append(InputGitTreeElement( + path=file_path, + mode=file_info.get("mode", "100644"), + type="blob", + sha=blob.sha + )) + + # Create new tree + new_tree = self.git_repo.create_git_tree(tree_elements, base_tree) + + # Create commit + new_commit = self.git_repo.create_git_commit( + commit_message, + new_tree, + [current_commit] + ) + + # Update branch reference + ref.edit(new_commit.sha) + + return new_commit + + @log_function_entry_exit() + def _update_file( + self, file_path: str, new_content: str, commit_message: str, branch_name: str = None + ) -> Optional[Dict]: + try: + branch_name = self.git_repo.default_branch if branch_name is None else branch_name + + # get the current file + file = self.git_repo.get_contents(file_path, ref=branch_name) + + # update the file + result = self.git_repo.update_file( + path=file_path, + message=commit_message, + content=new_content, + sha=file.sha, + branch=branch_name + ) + + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", + "File updated successfully. Commit SHA: '%s'", result["commit"].sha) + return result + + except Exception as err: + log_message(LoggingScope.TASK_OPS, "ERROR", "Error updating file: '%s'", err) + return None + + @log_function_entry_exit() + def _sorted_list_of_sequence_numbers(self) -> List[int]: + """Create a sorted list of sequence numbers from the pull requests directory""" + # a pull request's directory is of the form REPO/PR/SEQ + # hence, we can get all sequence numbers from the pull requests directory REPO/PR + sequence_numbers = [] + repo_pr_dir = f"{self.description.get_repo_name()}/{self.description.get_pr_number()}" + + # iterate over all directories under repo_pr_dir + try: + directories = self._list_directory_contents(repo_pr_dir) + for dir in directories: + # check if the directory is a number + if dir.name.isdigit(): + sequence_numbers.append(int(dir.name)) + else: + # directory is not a number, so we skip it + continue + except FileNotFoundError: + # repo_pr_dir does not exist, so we return an empty dictionary + log_message(LoggingScope.TASK_OPS, "ERROR", "Pull requests directory '%s' does not exist", repo_pr_dir) + except GithubException as err: + if err.status != 404: # 404 is catched by FileNotFoundError + # some other error than the directory not existing + log_message(LoggingScope.TASK_OPS, "ERROR", + "Some other error than the directory not existing: '%s'", err) + except Exception as err: + log_message(LoggingScope.TASK_OPS, "ERROR", "Unexpected error: '%s'", err) + + return sorted(sequence_numbers) + + @log_function_entry_exit() + def _determine_sequence_number(self) -> int: + """Determine the sequence number for the task""" + + sequence_numbers = self._sorted_list_of_sequence_numbers() + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "number of sequence numbers: %d", len(sequence_numbers)) + if len(sequence_numbers) == 0: + log_message(LoggingScope.TASK_OPS, "INFO", "no sequence numbers found, returning 0") + return 0 + + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", + "sequence numbers: [%s]", ", ".join(map(str, sequence_numbers))) + + # get the highest sequence number + highest_sequence_number = sequence_numbers[-1] + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "highest sequence number: %d", highest_sequence_number) + + pull_request = self._find_pr_for_sequence_number(highest_sequence_number) + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "pull request: '%s'", 
pull_request) + + if pull_request is None: + log_message(LoggingScope.TASK_OPS, "INFO", "Did not find pull request for sequence number %d", + highest_sequence_number) + # the directory for the sequence number exists but no PR yet + return highest_sequence_number + else: + log_message(LoggingScope.TASK_OPS, "INFO", "pull request found: '%s'", pull_request) + log_message(LoggingScope.TASK_OPS, "INFO", "pull request state/merged: '%s/%s'", + pull_request.state, str(pull_request.is_merged())) + if pull_request.is_merged(): + # the PR is merged, so we use the next sequence number + return highest_sequence_number + 1 + else: + # the PR is not merged, it may be closed though + if pull_request.state == 'closed': + # PR has been closed, so we return the next sequence number + return highest_sequence_number + 1 + else: + # PR is not closed, so we return the current highest sequence number + return highest_sequence_number + + @log_function_entry_exit() + def _handle_add_undetermined(self): + """Handler for ADD action in UNDETERMINED state""" + log_message(LoggingScope.TASK_OPS, "INFO", "Handling ADD action in UNDETERMINED state: '%s'", + self.description.get_task_file_name()) + # task is in state UNDETERMINED if there is no pull request directory for the task yet + # + # create pull request directory (REPO/PR/SEQ/TASK_FILE_NAME/) + # create task file in pull request directory (PULL_REQUEST_DIR/TaskDescription) + # create task status file in pull request directory (PULL_REQUEST_DIR/TaskState) + # create pointer file from task file path to pull request directory (remote_file_path -> PULL_REQUEST_DIR) + repo_name = self.description.get_repo_name() + pr_number = self.description.get_pr_number() + sequence_number = self._determine_sequence_number() # corresponds to an open or yet to be created PR + task_file_name = self.description.get_task_file_name() + # we cannot use self._determine_pull_request_dir() here because it requires a task pointer file + # and we don't have one yet + pull_request_dir = f"{repo_name}/{pr_number}/{sequence_number}/{task_file_name}" + task_description_file_path = f"{pull_request_dir}/TaskDescription" + task_state_file_path = f"{pull_request_dir}/TaskState" + remote_file_path = self.description.task_object.remote_file_path + + files_to_commit = { + task_description_file_path: { + "content": self.description.get_contents(), + "mode": "100644" + }, + task_state_file_path: { + "content": f"{EESSITaskState.NEW_TASK.name}\n", + "mode": "100644" + }, + remote_file_path: { + "content": f"remote_file_path = {remote_file_path}\npull_request_dir = {pull_request_dir}", + "mode": "100644" + } + } + + branch_name = self.git_repo.default_branch + try: + commit = self._create_multi_file_commit( + files_to_commit, + f"new task for {repo_name} PR {pr_number} seq {sequence_number}", + branch_name=branch_name + ) + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "commit created: '%s'", commit) + except Exception as err: + log_message(LoggingScope.TASK_OPS, "ERROR", "Error creating commit: '%s'", err) + # TODO: rollback previous changes (task description file, task state file) + return EESSITaskState.UNDETERMINED + + # TODO: verify that the sequence number is still valid (PR corresponding to the sequence number + # is still open or yet to be created); if it is not valid, perform corrective actions + return EESSITaskState.NEW_TASK + + @log_function_entry_exit() + def _update_task_state_file(self, next_state: EESSITaskState, branch_name: str = None) -> Optional[Dict]: + """Update the TaskState file 
content in default or given branch""" + branch_name = self.git_repo.default_branch if branch_name is None else branch_name + + task_pointer_file = self.description.task_object.remote_file_path + pull_request_dir = self._read_pull_request_dir_from_file(task_pointer_file, branch_name) + task_state_file_path = f"{pull_request_dir}/TaskState" + arch = self.description.get_metadata_filename_components()[3] + commit_message = f"change task state to {next_state} in {branch_name} for {arch}" + result = self._update_file(task_state_file_path, + f"{next_state.name}\n", + commit_message, + branch_name=branch_name) + return result + + @log_function_entry_exit() + def _init_payload_object(self): + """Initialize the payload object""" + if self.payload is not None: + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "payload object already initialized") + return + + # get name of of payload from metadata + payload_name = self.description.metadata["payload"]["filename"] + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "payload_name: '%s'", payload_name) + + # get config and remote_client from self.description.task_object + config = self.description.task_object.config + remote_client = self.description.task_object.remote_client + + # determine remote_file_path by replacing basename of remote_file_path in self.description.task_object + # with payload_name + description_remote_file_path = self.description.task_object.remote_file_path + payload_remote_file_path = os.path.join(os.path.dirname(description_remote_file_path), payload_name) + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "payload_remote_file_path: '%s'", payload_remote_file_path) + + # initialize payload object + payload_object = EESSIDataAndSignatureObject(config, payload_remote_file_path, remote_client) + self.payload = EESSITaskPayload(payload_object) + log_message(LoggingScope.TASK_OPS, "INFO", "payload: '%s'", self.payload) + + @log_function_entry_exit() + def _handle_add_new_task(self): + """Handler for ADD action in NEW_TASK state""" + log_message(LoggingScope.TASK_OPS, "INFO", "Handling ADD action in NEW_TASK state: '%s'", + self.description.get_task_file_name()) + # determine next state + next_state = self._next_state(EESSITaskState.NEW_TASK) + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "next_state: '%s'", next_state) + + # initialize payload object + self._init_payload_object() + + # update TaskState file content + self._update_task_state_file(next_state) + + # TODO: verify that the sequence number is still valid (PR corresponding to the sequence number + # is still open or yet to be created); if it is not valid, perform corrective actions + return next_state + + @log_function_entry_exit() + def _find_pr_for_branch(self, branch_name: str) -> Optional[PullRequest]: + """ + Find the single PR for the given branch in any state. 
+ + Args: + repo: GitHub repository + branch_name: Name of the branch + + Returns: + PullRequest object if found, None otherwise + """ + try: + prs = [pr for pr in list(self.git_repo.get_pulls(state="all")) + if pr.head.ref == branch_name] + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "number of PRs found: %d", len(prs)) + if len(prs): + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", + "1st PR found: %d, '%s'", prs[0].number, prs[0].head.ref) + return prs[0] if prs else None + except Exception as err: + log_message(LoggingScope.TASK_OPS, "ERROR", "Error finding PR for branch '%s': '%s'", branch_name, err) + return None + + @log_function_entry_exit() + def _find_pr_for_sequence_number(self, sequence_number: int) -> Optional[PullRequest]: + """Find the PR for the given sequence number""" + repo_name = self.description.get_repo_name() + pr_number = self.description.get_pr_number() + feature_branch_name = f"{repo_name.replace('/', '-')}-PR-{pr_number}-SEQ-{sequence_number}" + + # list all PRs with head_ref starting with the feature branch name without the sequence number + last_dash = feature_branch_name.rfind("-") + if last_dash != -1: + head_ref_wout_seq_num = feature_branch_name[:last_dash + 1] # +1 to include the separator + else: + head_ref_wout_seq_num = feature_branch_name + + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", + "searching for PRs whose head_ref starts with: '%s'", head_ref_wout_seq_num) + + all_prs = [pr for pr in list(self.git_repo.get_pulls(state="all")) + if pr.head.ref.startswith(head_ref_wout_seq_num)] + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", " number of PRs found: %d", len(all_prs)) + for pr in all_prs: + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", " PR #%d: '%s'", pr.number, pr.head.ref) + + # now, find the PR for the feature branch name (if any) + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", + "searching PR for feature branch name: '%s'", feature_branch_name) + pull_request = self._find_pr_for_branch(feature_branch_name) + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "pull request for branch '%s': '%s'", + feature_branch_name, pull_request) + return pull_request + + @log_function_entry_exit() + def _determine_sequence_number_from_pull_request_directory(self) -> int: + """Determine the sequence number from the pull request directory name""" + task_pointer_file = self.description.task_object.remote_file_path + pull_request_dir = self._read_pull_request_dir_from_file(task_pointer_file, self.git_repo.default_branch) + # pull_request_dir is of the form REPO/PR/SEQ/TASK_FILE_NAME/ (REPO contains a '/' separating the org and repo) + _, _, _, seq, _ = pull_request_dir.split("/") + return int(seq) + + @log_function_entry_exit() + def _determine_feature_branch_name(self) -> str: + """Determine the feature branch name from the pull request directory name""" + task_pointer_file = self.description.task_object.remote_file_path + pull_request_dir = self._read_pull_request_dir_from_file(task_pointer_file, self.git_repo.default_branch) + # pull_request_dir is of the form REPO/PR/SEQ/TASK_FILE_NAME/ (REPO contains a '/' separating the org and repo) + org, repo, pr, seq, _ = pull_request_dir.split("/") + return f"{org}-{repo}-PR-{pr}-SEQ-{seq}" + + @log_function_entry_exit() + def _update_task_states(self, next_state: EESSITaskState, default_branch_name: str, + approved_state: EESSITaskState, feature_branch_name: str): + """ + Update task states in default and feature branches + + States have to be updated in a specific order and in 
particular the default branch has to be + merged into the feature branch before the feature branch can be updated to avoid a merge conflict. + + Args: + next_state: next state to be applied to the default branch + default_branch_name: name of the default branch + approved_state: state to be applied to the feature branch + feature_branch_name: name of the feature branch + """ + # TODO: add failure handling (capture failures and return them somehow) + + # update TaskState file content + # - next_state in default branch (interpreted as current state) + # - approved_state in feature branch (interpreted as future state, ie, after + # the PR corresponding to the feature branch will be merged) + + # first, update the task state file in the default branch + self._update_task_state_file(next_state, branch_name=default_branch_name) + + # second, merge default branch into feature branch (to avoid a merge conflict) + # TODO: store arch info (CPU+ACCEL) in task/metdata file and then access that rather + # than using a part of the file name + arch = self.description.get_metadata_filename_components()[3] + commit_message = f"merge {default_branch_name} into {feature_branch_name} for {arch}" + self.git_repo.merge( + head=default_branch_name, + base=feature_branch_name, + commit_message=commit_message + ) + + # last, update task state file in feature branch + self._update_task_state_file(approved_state, branch_name=feature_branch_name) + log_message(LoggingScope.TASK_OPS, "INFO", + "TaskState file updated to '%s' in default branch '%s' and to '%s' in feature branch '%s'", + next_state, default_branch_name, approved_state, feature_branch_name) + + @log_function_entry_exit() + def _create_task_summary(self) -> str: + """Analyse contents of current task and create a file for it in the REPO-PR-SEQ directory.""" + + # determine task summary file path in feature branch on GitHub + feature_branch_name = self._determine_feature_branch_name() + pull_request_dir = self._determine_pull_request_dir(branch_name=feature_branch_name) + task_summary_file_path = f"{pull_request_dir}/TaskSummary.html" + + # check if task summary file already exists in repo on GitHub + if self._path_exists_in_branch(task_summary_file_path, feature_branch_name): + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", + "task summary file already exists: '%s'", task_summary_file_path) + task_summary = self.git_repo.get_contents(task_summary_file_path, ref=feature_branch_name) + # return task_summary.decoded_content + return task_summary + + # create task summary + payload_name = self.description.metadata["payload"]["filename"] + payload_summary = self.payload.analyse_contents(self.config) + metadata_contents = self.description.get_contents() + + task_summary = self.config["github"]["task_summary_payload_template"].format( + payload_name=payload_name, + metadata_contents=metadata_contents, + payload_overview=payload_summary + ) + + # create HTML file with task summary in REPO-PR-SEQ directory + # TODO: add failure handling (capture result and act on it) + task_file_name = self.description.get_task_file_name() + commit_message = f"create summary for {task_file_name} in {feature_branch_name}" + self._safe_create_file(task_summary_file_path, commit_message, task_summary, + branch_name=feature_branch_name) + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "task summary file created: '%s'", task_summary_file_path) + + # return task summary + return task_summary + + @log_function_entry_exit() + def _create_pr_contents_overview(self) -> str: + 
"""Create a contents overview for the pull request""" + # TODO: implement + feature_branch_name = self._determine_feature_branch_name() + task_pointer_file = self.description.task_object.remote_file_path + pull_request_dir = self._read_pull_request_dir_from_file(task_pointer_file, feature_branch_name) + pr_dir = os.path.dirname(pull_request_dir) + directories = self._list_directory_contents(pr_dir, feature_branch_name) + contents_overview = "" + if directories: + contents_overview += "\n" + for directory in directories: + task_summary_file_path = f"{pr_dir}/{directory.name}/TaskSummary.html" + if self._path_exists_in_branch(task_summary_file_path, feature_branch_name): + file_contents = self.git_repo.get_contents(task_summary_file_path, ref=feature_branch_name) + task_summary = base64.b64decode(file_contents.content).decode("utf-8") + contents_overview += f"{task_summary}\n" + else: + contents_overview += f"Task summary file not found: {task_summary_file_path}\n" + contents_overview += "\n" + else: + contents_overview += "No tasks found in this PR\n" + + print(f"contents_overview: {contents_overview}") + return contents_overview + + @log_function_entry_exit() + def _create_pull_request(self, feature_branch_name: str, default_branch_name: str): + """ + Create a PR from the feature branch to the default branch + + Args: + feature_branch_name: name of the feature branch + default_branch_name: name of the default branch + """ + pr_title_format = self.config["github"]["grouped_pr_title"] + pr_body_format = self.config["github"]["grouped_pr_body"] + repo_name = self.description.get_repo_name() + pr_number = self.description.get_pr_number() + pr_url = f"https://github.com/{repo_name}/pull/{pr_number}" + seq_num = self._determine_sequence_number_from_pull_request_directory() + pr_title = pr_title_format.format( + cvmfs_repo=self.cvmfs_repo, + pr=pr_number, + repo=repo_name, + seq_num=seq_num, + ) + self._create_task_summary() + contents_overview = self._create_pr_contents_overview() + pr_body = pr_body_format.format( + cvmfs_repo=self.cvmfs_repo, + pr=pr_number, + pr_url=pr_url, + repo=repo_name, + seq_num=seq_num, + contents=contents_overview, + analysis="
TO BE DONE
", + action="
TO BE DONE
", + ) + pr = self.git_repo.create_pull( + title=pr_title, + body=pr_body, + head=feature_branch_name, + base=default_branch_name + ) + log_message(LoggingScope.TASK_OPS, "INFO", "PR created: '%s'", pr) + + @log_function_entry_exit() + def _update_pull_request(self, pull_request: PullRequest): + """ + Update the pull request + + Args: + pull_request: instance of the pull request + """ + # TODO: update sections (contents analysis, action) + repo_name = self.description.get_repo_name() + pr_number = self.description.get_pr_number() + pr_url = f"https://github.com/{repo_name}/pull/{pr_number}" + seq_num = self._determine_sequence_number_from_pull_request_directory() + + self._create_task_summary() + contents_overview = self._create_pr_contents_overview() + pr_body_format = self.config["github"]["grouped_pr_body"] + pr_body = pr_body_format.format( + cvmfs_repo=self.cvmfs_repo, + pr=pr_number, + pr_url=pr_url, + repo=repo_name, + seq_num=seq_num, + contents=contents_overview, + analysis="
TO BE DONE
", + action="
TO BE DONE
", + ) + pull_request.edit(body=pr_body) + + log_message(LoggingScope.TASK_OPS, "INFO", "PR updated: '%s'", pull_request) + + @log_function_entry_exit() + def _handle_add_payload_staged(self): + """Handler for ADD action in PAYLOAD_STAGED state""" + log_message(LoggingScope.TASK_OPS, "INFO", "Handling ADD action in PAYLOAD_STAGED state: '%s'", + self.description.get_task_file_name()) + next_state = self._next_state(EESSITaskState.PAYLOAD_STAGED) + approved_state = EESSITaskState.APPROVED + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", + "next_state: '%s', approved_state: '%s'", next_state, approved_state) + + default_branch_name = self.git_repo.default_branch + default_branch = self._get_branch_from_name(default_branch_name) + default_sha = default_branch.commit.sha + feature_branch_name = self._determine_feature_branch_name() + feature_branch = self._get_branch_from_name(feature_branch_name) + if not feature_branch: + # feature branch does not exist + # TODO: could have been merged already --> check if PR corresponding to the feature branch exists + # ASSUME: it has not existed before --> create it + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", + "branch '%s' does not exist, creating it", feature_branch_name) + + feature_branch = self.git_repo.create_git_ref(f"refs/heads/{feature_branch_name}", default_sha) + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", + "branch '%s' created: '%s'", feature_branch_name, feature_branch) + else: + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", + "found existing branch for '%s': '%s'", feature_branch_name, feature_branch) + + pull_request = self._find_pr_for_branch(feature_branch_name) + if not pull_request: + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", + "no PR found for branch '%s'", feature_branch_name) + + # TODO: add failure handling (capture result and act on it) + self._update_task_states(next_state, default_branch_name, approved_state, feature_branch_name) + + # TODO: add failure handling (capture result and act on it) + self._create_pull_request(feature_branch_name, default_branch_name) + + return EESSITaskState.PULL_REQUEST + else: + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", + "found existing PR for branch '%s': '%s'", feature_branch_name, pull_request) + # TODO: check if PR is open or closed + if pull_request.state == "closed": + log_message(LoggingScope.TASK_OPS, "INFO", + "PR '%s' is closed, creating issue", pull_request) + # TODO: create issue + return EESSITaskState.PAYLOAD_STAGED + else: + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", + "PR '%s' is open, updating task states", pull_request) + # TODO: add failure handling (capture result and act on it) + # THINK about what a failure would mean and what to do about it. 
+ self._update_task_states(next_state, default_branch_name, approved_state, feature_branch_name) + + # TODO: add failure handling (capture result and act on it) + self._update_pull_request(pull_request) + + return EESSITaskState.PULL_REQUEST + + @log_function_entry_exit() + def _handle_add_pull_request(self): + """Handler for ADD action in PULL_REQUEST state""" + log_message(LoggingScope.TASK_OPS, "INFO", "Handling ADD action in PULL_REQUEST state: '%s'", + self.description.get_task_file_name()) + # Implementation for adding in PULL_REQUEST state + # we got here because the state of the task is PULL_REQUEST in the default branch + # determine branch and PR and state of PR + # PR is open --> just return EESSITaskState.PULL_REQUEST + # PR is closed & merged --> deployment is approved + # PR is closed & not merged --> deployment is rejected + feature_branch_name = self._determine_feature_branch_name() + # TODO: check if feature branch exists, for now ASSUME it does + pull_request = self._find_pr_for_branch(feature_branch_name) + if pull_request: + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", + "found PR for branch '%s': '%s'", feature_branch_name, pull_request) + if pull_request.state == "closed": + if pull_request.merged: + log_message(LoggingScope.TASK_OPS, "INFO", + "PR '%s' is closed and merged (strange that state is PULL_REQUEST)", pull_request) + # TODO: How could we ended up here? state in default branch is PULL_REQUEST but + # PR is merged, hence it should have been in the APPROVED state + # ==> for now, just return EESSITaskState.PULL_REQUEST + # + # there is the possibility that the PR was updated just before the + # PR was merged + # WHY is it a problem? because a task may have been accepted that wouldn't + # have been accepted or worse shouldn't been accepted + # WHAT to do? ACCEPT/IGNORE THE ISSUE FOR NOw + # HOWEVER, the contents of the PR directory may be inconsistent with + # respect to the TaskState file and missing TaskSummary.html file + # WE could create an issue and only return EESSITaskState.APPROVED if the + # issue is closed + # WE could also defer all handling of this to the handler for the + # APPROVED state + # NOPE, we have to do some handling here, at least for the tasks where their + # state file did + # --> check if we could have ended up here? If so, create an issue. + # Do we need a state ISSUE_OPENED to avoid processing the task again? + return EESSITaskState.PULL_REQUEST + else: + log_message(LoggingScope.TASK_OPS, "INFO", + "PR '%s' is closed and not merged, returning REJECTED state", pull_request) + # TODO: there is the possibility that the PR was updated just before the + # PR was closed + # WHY is it a problem? because a task may have been rejected that wouldn't + # have been rejected or worse shouldn't been rejected + # WHAT to do? 
ACCEPT/IGNORE THE ISSUE FOR NOw + # HOWEVER, the contents of the PR directory may be inconsistent with + # respect to the TaskState file and missing TaskSummary.html file + # WE could create an issue and only return EESSITaskState.REJECTED if the + # issue is closed + # WE could also defer all handling of this to the handler for the + # REJECTED state + # FOR NOW, we assume that the task was rejected on purpose + # we need to change the state of the task in the default branch to REJECTED + self._update_task_state_file(EESSITaskState.REJECTED) + return EESSITaskState.REJECTED + else: + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", + "PR '%s' is open, returning PULL_REQUEST state", pull_request) + return EESSITaskState.PULL_REQUEST + else: + log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", + "no PR found for branch '%s'", feature_branch_name) + # the method was called because the state of the task is PULL_REQUEST in the default branch + # however, it's weird that the PR was not found for the feature branch + # TODO: may create or update an issue for the task or deployment + return EESSITaskState.PULL_REQUEST + + return EESSITaskState.PULL_REQUEST + + @log_function_entry_exit() + def _perform_task_action(self) -> bool: + """Perform the task action""" + # TODO: support other actions than ADD + if self.action == EESSITaskAction.ADD: + return self._perform_task_add() + else: + raise ValueError(f"Task action '{self.action}' not supported (yet)") + + @log_function_entry_exit() + def _issue_exists(self, title: str, state: str = "open") -> bool: + """ + Check if an issue with the given title and state already exists. + """ + issues = self.git_repo.get_issues(state=state) + for issue in issues: + if issue.title == title and issue.state == state: + return True + else: + return False + + @log_function_entry_exit() + def _perform_task_add(self) -> bool: + """Perform the ADD task action""" + # TODO: verify checksum here or before? 
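# One way the "verify checksum" TODO above could be resolved (a sketch under assumptions:
# the 'sha256' field in the payload section of the task metadata is hypothetical and is not
# defined by this patch).
def _payload_checksum_matches(self) -> bool:
    """Compare the sha256 of the downloaded payload against the value from the task metadata."""
    import hashlib
    expected = self.description.metadata.get("payload", {}).get("sha256")
    if expected is None:
        # nothing to verify against
        return True
    digest = hashlib.sha256()
    with open(self.payload.payload_object.local_file_path, "rb") as handle:
        for chunk in iter(lambda: handle.read(1024 * 1024), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected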
+ script = self.config["paths"]["ingestion_script"] + sudo = ["sudo"] if self.config["cvmfs"].getboolean("ingest_as_root", True) else [] + log_message(LoggingScope.STATE_OPS, "INFO", + "Running the ingestion script for '%s'...\n with script: '%s'\n with sudo: '%s'", + self.description.get_task_file_name(), + script, "no" if sudo == [] else "yes") + ingest_cmd = subprocess.run( + sudo + [script, self.cvmfs_repo, str(self.payload.payload_object.local_file_path)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + log_message(LoggingScope.STATE_OPS, "INFO", + "Ingestion script returned code '%s'", ingest_cmd.returncode) + log_message(LoggingScope.STATE_OPS, "INFO", + "Ingestion script stdout: '%s'", ingest_cmd.stdout.decode("UTF-8")) + log_message(LoggingScope.STATE_OPS, "INFO", + "Ingestion script stderr: '%s'", ingest_cmd.stderr.decode("UTF-8")) + if ingest_cmd.returncode == 0: + next_state = self._next_state(EESSITaskState.APPROVED) + self._update_task_state_file(next_state) + if self.config.has_section("slack") and self.config["slack"].getboolean("ingestion_notification", False): + send_slack_message( + self.config["secrets"]["slack_webhook"], + self.config["slack"]["ingestion_message"].format( + tarball=os.path.basename(self.payload.payload_object.local_file_path), + cvmfs_repo=self.cvmfs_repo) + ) + return True + else: + tarball = os.path.basename(self.payload.payload_object.local_file_path) + log_message(LoggingScope.STATE_OPS, "ERROR", + "Failed to add '%s', return code '%s'", + tarball, + ingest_cmd.returncode) + + issue_title = f"Failed to add '{tarball}'" + log_message(LoggingScope.STATE_OPS, "INFO", + "Creating issue for failed ingestion: title: '%s'", + issue_title) + + command = " ".join(ingest_cmd.args) + failed_ingestion_issue_body = self.config["github"]["failed_ingestion_issue_body"] + issue_body = failed_ingestion_issue_body.format( + command=command, + tarball=tarball, + return_code=ingest_cmd.returncode, + stdout=ingest_cmd.stdout.decode("UTF-8"), + stderr=ingest_cmd.stderr.decode("UTF-8") + ) + log_message(LoggingScope.STATE_OPS, "INFO", + "Creating issue for failed ingestion: body: '%s'", + issue_body) + + if self._issue_exists(issue_title, state="open"): + log_message(LoggingScope.STATE_OPS, "INFO", + "Failed to add '%s', but an open issue already exists, skipping...", + os.path.basename(self.payload.payload_object.local_file_path)) + else: + log_message(LoggingScope.STATE_OPS, "INFO", + "Failed to add '%s', but an open issue does not exist, creating one...", + os.path.basename(self.payload.payload_object.local_file_path)) + self.git_repo.create_issue(title=issue_title, body=issue_body) + return False + + @log_function_entry_exit() + def _handle_add_approved(self): + """Handler for ADD action in APPROVED state""" + log_message(LoggingScope.TASK_OPS, "INFO", "Handling ADD action in APPROVED state: '%s'", + self.description.get_task_file_name()) + # Implementation for adding in APPROVED state + # If successful, _perform_task_action() will change the state + # to INGESTED on GitHub + try: + if self._perform_task_action(): + return EESSITaskState.INGESTED + else: + return EESSITaskState.APPROVED + except Exception as err: + log_message(LoggingScope.TASK_OPS, "ERROR", + "Error performing task action: '%s'\nTraceback:\n%s", err, traceback.format_exc()) + return EESSITaskState.APPROVED + + @log_function_entry_exit() + def _handle_add_ingested(self): + """Handler for ADD action in INGESTED state""" + log_message(LoggingScope.TASK_OPS, "INFO", "Handling ADD action in 
INGESTED state: '%s'", + self.description.get_task_file_name()) + # Implementation for adding in INGESTED state + # DONT change state on GitHub, because the result + # (INGESTED/REJECTED) would be overwritten + return EESSITaskState.DONE + + @log_function_entry_exit() + def _handle_add_rejected(self): + """Handler for ADD action in REJECTED state""" + log_message(LoggingScope.TASK_OPS, "INFO", "Handling ADD action in REJECTED state: '%s'", + self.description.get_task_file_name()) + # Implementation for adding in REJECTED state + # DONT change state on GitHub, because the result + # (INGESTED/REJECTED) would be overwritten + return EESSITaskState.DONE + + @log_function_entry_exit() + def __str__(self): + return f"EESSITask(description={self.description}, action={self.action}, state={self.determine_state()})" diff --git a/scripts/automated_ingestion/eessi_task_action.py b/scripts/automated_ingestion/eessi_task_action.py new file mode 100644 index 00000000..6f141435 --- /dev/null +++ b/scripts/automated_ingestion/eessi_task_action.py @@ -0,0 +1,12 @@ +from enum import Enum, auto + + +class EESSITaskAction(Enum): + NOP = auto() # perform no action + DELETE = auto() # perform a delete operation + ADD = auto() # perform an add operation + UPDATE = auto() # perform an update operation + UNKNOWN = auto() # unknown action + + def __str__(self): + return self.name.lower() diff --git a/scripts/automated_ingestion/eessi_task_description.py b/scripts/automated_ingestion/eessi_task_description.py new file mode 100644 index 00000000..24cc6df7 --- /dev/null +++ b/scripts/automated_ingestion/eessi_task_description.py @@ -0,0 +1,188 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Tuple + +import json + +from eessi_data_object import EESSIDataAndSignatureObject +from eessi_logging import log_function_entry_exit, log_message, LoggingScope +from eessi_remote_storage_client import DownloadMode + + +@dataclass +class EESSITaskDescription: + """Class representing an EESSI task to be performed, including its metadata and associated data files.""" + + # The EESSI data and signature object associated with this task + task_object: EESSIDataAndSignatureObject + + # Whether the signature was successfully verified + signature_verified: bool = False + + # Metadata from the task description file + metadata: Dict[str, Any] = None + + # task element + task: Dict[str, Any] = None + + # source element + source: Dict[str, Any] = None + + @log_function_entry_exit() + def __init__(self, task_object: EESSIDataAndSignatureObject): + """ + Initialize an EESSITaskDescription object. 
+ + Args: + task_object: The EESSI data and signature object associated with this task + """ + self.task_object = task_object + self.metadata = {} + + self.task_object.download(mode=DownloadMode.CHECK_REMOTE) + + # verify signature and set initial state + self.signature_verified = self.task_object.verify_signature() + + # try to read metadata (will only succeed if signature is verified) + try: + self._read_metadata() + except RuntimeError: + # expected if signature is not verified yet + pass + + # check if the task file contains a task field and add that to self + if "task" in self.metadata: + self.task = self.metadata["task"] + else: + self.task = None + + # check if the task file contains a link2pr field and add that to source element + if "link2pr" in self.metadata: + self.source = self.metadata["link2pr"] + else: + self.source = None + + @log_function_entry_exit() + def get_contents(self) -> str: + """ + Get the contents of the task description / metadata file. + """ + return self.raw_contents + + @log_function_entry_exit() + def get_metadata_filename_components(self) -> Tuple[str, str, str, str, str, str]: + """ + Get the components of the metadata file name. + + An example of the metadata file name is: + eessi-2023.06-software-linux-x86_64-amd-zen2-1745557626.tar.gz.meta.txt + + The components are: + eessi: some prefix + VERSION: 2023.06 + COMPONENT: software + OS: linux + ARCHITECTURE: x86_64-amd-zen2 + TIMESTAMP: 1745557626 + SUFFIX: tar.gz.meta.txt + + The ARCHITECTURE component can include one to two hyphens. + The SUFFIX is the part after the first dot (no other components should include dots). + """ + # obtain file name from local file path using basename + file_name = Path(self.task_object.local_file_path).name + # split file_name into part before suffix and the suffix + # idea: split on last hyphen, then split on first dot + suffix = file_name.split("-")[-1].split(".", 1)[1] + file_name_without_suffix = file_name.strip(f".{suffix}") + # from file_name_without_suffix determine VERSION (2nd element), COMPONENT (3rd element), OS (4th element), + # ARCHITECTURE (5th to second last elements) and TIMESTAMP (last element) + components = file_name_without_suffix.split("-") + version = components[1] + component = components[2] + os = components[3] + architecture = "-".join(components[4:-1]) + timestamp = components[-1] + return version, component, os, architecture, timestamp, suffix + + @log_function_entry_exit() + def get_metadata_value(self, key: str) -> str: + """ + Get the value of a key from the task description / metadata file. 
+ """ + # check that key is defined and has a length > 0 + if not key or len(key) == 0: + raise ValueError("get_metadata_value: key is not defined or has a length of 0") + + value = None + task = self.task + source = self.source + # check if key is in task or source + if task and key in task: + value = task[key] + log_message(LoggingScope.TASK_OPS, "INFO", + f"Value '{value}' for key '{key}' found in information from task metadata: {task}") + elif source and key in source: + value = source[key] + log_message(LoggingScope.TASK_OPS, "INFO", + f"Value '{value}' for key '{key}' found in information from source metadata: {source}") + else: + log_message(LoggingScope.TASK_OPS, "INFO", + f"Value for key '{key}' neither found in task metadata nor source metadata") + raise ValueError(f"Value for key '{key}' neither found in task metadata nor source metadata") + return value + + @log_function_entry_exit() + def get_pr_number(self) -> str: + """ + Get the PR number from the task description / metadata file. + """ + return self.get_metadata_value("pr") + + @log_function_entry_exit() + def get_repo_name(self) -> str: + """ + Get the repository name from the task description / metadata file. + """ + return self.get_metadata_value("repo") + + @log_function_entry_exit() + def get_task_file_name(self) -> str: + """ + Get the file name from the task description / metadata file. + """ + # get file name from remote file path using basename + file_name = Path(self.task_object.remote_file_path).name + return file_name + + @log_function_entry_exit() + def _read_metadata(self) -> None: + """ + Internal method to read and parse the metadata from the task description file. + Only reads metadata if the signature has been verified. + """ + if not self.signature_verified: + log_message(LoggingScope.ERROR, "ERROR", "Cannot read metadata: signature not verified for '%s'", + self.task_object.local_file_path) + raise RuntimeError("Cannot read metadata: signature not verified") + + try: + with open(self.task_object.local_file_path, "r") as file: + self.raw_contents = file.read() + self.metadata = json.loads(self.raw_contents) + log_message(LoggingScope.DEBUG, "DEBUG", "Successfully read metadata from '%s'", + self.task_object.local_file_path) + except json.JSONDecodeError as err: + log_message(LoggingScope.ERROR, "ERROR", "Failed to parse JSON in task description file '%s': '%s'", + self.task_object.local_file_path, str(err)) + raise + except Exception as err: + log_message(LoggingScope.ERROR, "ERROR", "Failed to read task description file '%s': '%s'", + self.task_object.local_file_path, str(err)) + raise + + @log_function_entry_exit() + def __str__(self) -> str: + """Return a string representation of the EESSITaskDescription object.""" + return f"EESSITaskDescription({self.task_object.local_file_path}, verified={self.signature_verified})" diff --git a/scripts/automated_ingestion/eessi_task_payload.py b/scripts/automated_ingestion/eessi_task_payload.py new file mode 100644 index 00000000..112fcfd1 --- /dev/null +++ b/scripts/automated_ingestion/eessi_task_payload.py @@ -0,0 +1,119 @@ +from dataclasses import dataclass +from pathlib import PurePosixPath +from typing import Dict + +import os +import tarfile + +from eessi_data_object import EESSIDataAndSignatureObject +from eessi_logging import log_function_entry_exit +from eessi_remote_storage_client import DownloadMode + + +@dataclass +class EESSITaskPayload: + """Class representing an EESSI task payload (tarball/artifact) and its signature.""" + + # The EESSI data and 
signature object associated with this payload + payload_object: EESSIDataAndSignatureObject + + # Whether the signature was successfully verified + signature_verified: bool = False + + # possibly at a later point in time, we will add inferred metadata here + # such as the prefix in a tarball, the main elements, or which software + # package it includes + + @log_function_entry_exit() + def __init__(self, payload_object: EESSIDataAndSignatureObject): + """ + Initialize an EESSITaskPayload object. + + Args: + payload_object: The EESSI data and signature object associated with this payload + """ + self.payload_object = payload_object + + # download the payload and its signature + self.payload_object.download(mode=DownloadMode.CHECK_REMOTE) + + # verify signature + self.signature_verified = self.payload_object.verify_signature() + + @log_function_entry_exit() + def analyse_contents(self, config: Dict) -> str: + """Analyse the contents of the payload and return a summary in a ready-to-use HTML format.""" + tar = tarfile.open(self.payload_object.local_file_path, "r") + members = tar.getmembers() + tar_num_members = len(members) + paths = sorted([m.path for m in members]) + + # reduce limit for full listing from 100 to 3 because the description can + # include 10s of tarballs and thus even 100 maybe too many; using a very + # small number can still be useful if there is only a very small number + # of files, say an architecture specific configuration file + if tar_num_members < 3: + tar_members_desc = "Full listing of the contents of the tarball:" + members_list = paths + + else: + tar_members_desc = "Summarized overview of the contents of the tarball:" + # determine prefix after filtering out '/init' subdirectory, + # to get actual prefix for specific CPU target (like '2023.06/software/linux/aarch64/neoverse_v1') + init_subdir = os.path.join("*", "init") + non_init_paths = sorted( + [path for path in paths if not any(parent.match(init_subdir) for parent in PurePosixPath(path).parents)] + ) + if non_init_paths: + prefix = os.path.commonprefix(non_init_paths) + else: + prefix = os.path.commonprefix(paths) + + # TODO: this only works for software tarballs, how to handle compat layer tarballs? 
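# A possible direction for the compat-layer TODO above (illustrative only): compatibility-layer
# tarballs are assumed here to unpack under a prefix like '<version>/compat/<os>/<arch>', so one
# could branch on that before building the software/module/reprod listings below.
def _is_compat_layer_prefix(prefix: str) -> bool:
    """Heuristic: treat the tarball as a compat-layer tarball if 'compat' appears in its prefix."""
    return "compat" in PurePosixPath(prefix).parts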
+ swdirs = [ # all directory names with the pattern: /software// + member.path + for member in members + if member.isdir() and PurePosixPath(member.path).match(os.path.join(prefix, 'software', '*', '*')) + ] + modfiles = [ # all filenames with the pattern: /modules///*.lua + member.path + for member in members + if member.isfile() + and PurePosixPath(member.path).match(os.path.join(prefix, 'modules', '*', '*', '*.lua')) + ] + reprod_dirs = [ + member.path + for member in members + if member.isdir() and PurePosixPath(member.path).match(os.path.join(prefix, 'reprod', '*', '*', '*')) + ] + other = [ # anything that is not in /software nor /modules nor /reprod + member.path + for member in members + if ( + not PurePosixPath(prefix).joinpath('software') in PurePosixPath(member.path).parents + and not PurePosixPath(prefix).joinpath('modules') in PurePosixPath(member.path).parents + and not PurePosixPath(prefix).joinpath('reprod') in PurePosixPath(member.path).parents + ) + # if not fnmatch.fnmatch(m.path, os.path.join(prefix, 'software', '*')) + # and not fnmatch.fnmatch(m.path, os.path.join(prefix, 'modules', '*')) + ] + members_list = sorted(swdirs + modfiles + reprod_dirs + other) + + # construct the overview + overview = config["github"]["task_summary_payload_overview_template"].format( + tar_num_members=tar_num_members, + bucket_url=self.payload_object.remote_client.get_bucket_url(), + remote_file_path=self.payload_object.remote_file_path, + tar_members_desc=tar_members_desc, + tar_members="\n".join(members_list) + ) + + # make sure that the overview does not exceed Github's maximum length (65536 characters) + if len(overview) > 60000: + overview = overview[:60000] + "\n\nWARNING: output exceeded the maximum length and was truncated!\n```" + return overview + + @log_function_entry_exit() + def __str__(self) -> str: + """Return a string representation of the EESSITaskPayload object.""" + return f"EESSITaskPayload({self.payload_object.local_file_path}, verified={self.signature_verified})" diff --git a/scripts/automated_ingestion/ingest_bundles.py b/scripts/automated_ingestion/ingest_bundles.py new file mode 100644 index 00000000..363d7060 --- /dev/null +++ b/scripts/automated_ingestion/ingest_bundles.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python3 + +from eessi_data_object import EESSIDataAndSignatureObject +from eessi_task import EESSITask, EESSITaskState +from eessi_task_description import EESSITaskDescription +from eessi_s3_bucket import EESSIS3Bucket +from eessi_logging import error, log_function_entry_exit, log_message, LoggingScope, LOG_LEVELS, set_logging_scopes +from pid.decorator import pidfile # noqa: F401 +from pid import PidFileError + +import argparse +import configparser +import github +import json +import logging +import sys +from pathlib import Path +from typing import List + +REQUIRED_CONFIG = { + "secrets": ["aws_secret_access_key", "aws_access_key_id", "github_pat"], + "paths": ["download_dir", "ingestion_script", "metadata_file_extension"], + "aws": ["staging_buckets"], + "github": ["staging_repo", "failed_ingestion_issue_body", "pr_body"], +} + + +@log_function_entry_exit() +def parse_config(path): + """Parse the configuration file.""" + config = configparser.ConfigParser() + try: + config.read(path) + except Exception as err: + error(f"Unable to read configuration file '{path}'!\nException: '{err}'") + + # check if all required configuration parameters/sections can be found + for section in REQUIRED_CONFIG.keys(): + if section not in config: + error(f"Missing section 
'{section}' in configuration file '{path}'.") + for item in REQUIRED_CONFIG[section]: + if item not in config[section]: + error(f"Missing configuration item '{item}' in section '{section}' of configuration file '{path}'.") + + return config + + +@log_function_entry_exit() +def parse_args(): + """Parse the command-line arguments.""" + parser = argparse.ArgumentParser() + + # logging options + logging_group = parser.add_argument_group("Logging options") + logging_group.add_argument("--log-file", + help="Path to log file (overrides config file setting)") + logging_group.add_argument("--console-level", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Logging level for console output (overrides config file setting)") + logging_group.add_argument("--file-level", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Logging level for file output (overrides config file setting)") + logging_group.add_argument("--quiet", + action="store_true", + help="Suppress console output (overrides all other console settings)") + logging_group.add_argument("--log-scopes", + help="Comma-separated list of logging scopes using +/- syntax. " + "Examples: '+FUNC_ENTRY_EXIT' (enable only function entry/exit), " + "'+ALL,-FUNC_ENTRY_EXIT' (enable all except function entry/exit), " + "'+FUNC_ENTRY_EXIT,-EXAMPLE_SCOPE' (enable function entry/exit but disable example)") + + # existing arguments + parser.add_argument("-c", "--config", type=str, help="path to configuration file", + default="ingest_bundles.cfg", dest="config") + parser.add_argument("-d", "--debug", help="enable debug mode", action="store_true", dest="debug") + parser.add_argument("-l", "--list", help="only list available tasks", action="store_true", dest="list_only") + parser.add_argument("--extensions", help="comma-separated list of extensions to process (default: .task)", + nargs="?", const=".task", default=False) + + return parser.parse_args() + + +@log_function_entry_exit() +def setup_logging(config: configparser.ConfigParser, args: argparse.Namespace) -> logging.Logger: + """ + Configure logging based on configuration file and command line arguments. + Command line arguments take precedence over config file settings. 
+ + Args: + config: Configuration parser + args: Parsed command line arguments + + Returns: + Logger instance + """ + # get settings from config file + log_file = config["logging"].get("log_file") + config_console_level = LOG_LEVELS.get(config["logging"].get("console_level", "INFO").upper(), logging.INFO) + config_file_level = LOG_LEVELS.get(config["logging"].get("file_level", "DEBUG").upper(), logging.DEBUG) + + # override with command line arguments if provided + log_file = args.log_file if args.log_file else log_file + console_level = getattr(logging, args.console_level) if args.console_level else config_console_level + file_level = getattr(logging, args.file_level) if args.file_level else config_file_level + + # debug mode overrides console level + if args.debug: + console_level = logging.DEBUG + + # set up logging scopes + if args.log_scopes: + set_logging_scopes(args.log_scopes) + log_message(LoggingScope.DEBUG, "DEBUG", "Enabled logging scopes: '%s'", args.log_scopes) + + # create logger + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) # set root logger to lowest level + + # create formatters + console_formatter = logging.Formatter("%(levelname)-8s: %(message)s") + file_formatter = logging.Formatter("%(asctime)s - %(levelname)-8s: %(message)s") + + # console handler (only if not quiet) + if not args.quiet: + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(console_level) + console_handler.setFormatter(console_formatter) + logger.addHandler(console_handler) + + # file handler (if log file is specified) + if log_file: + # ensure log directory exists + log_path = Path(log_file) + log_path.parent.mkdir(parents=True, exist_ok=True) + + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(file_level) + file_handler.setFormatter(file_formatter) + logger.addHandler(file_handler) + + return logger + + +@pidfile("shared_lock.pid") # noqa: F401 +@log_function_entry_exit() +def main(): + """Main function.""" + args = parse_args() + config = parse_config(args.config) + _ = setup_logging(config, args) # noqa: F841 + + # TODO: check configuration: secrets, paths, permissions on dirs, etc + extensions = args.extensions.split(",") + gh_pat = config["secrets"]["github_pat"] + gh_staging_repo = github.Github(gh_pat).get_repo(config["github"]["staging_repo"]) + + buckets = json.loads(config["aws"]["staging_buckets"]) + for bucket, cvmfs_repo in buckets.items(): + # create our custom S3 bucket for this bucket + s3_bucket = EESSIS3Bucket(config, bucket) + + tasks = find_deployment_tasks(s3_bucket, extensions) + if args.list_only: + log_message(LoggingScope.GROUP_OPS, "INFO", "#tasks: %d", len(tasks)) + for num, task in enumerate(tasks): + log_message(LoggingScope.GROUP_OPS, "INFO", "[%s] %d: '%s'", bucket, num, task) + else: + # process each task file + for task_path in tasks: + log_message(LoggingScope.GROUP_OPS, "INFO", "Processing task: '%s'", task_path) + + try: + # create EESSITask for the task file + try: + task = EESSITask( + EESSITaskDescription(EESSIDataAndSignatureObject(config, task_path, s3_bucket)), + config, cvmfs_repo, gh_staging_repo + ) + + except Exception as err: + log_message(LoggingScope.ERROR, "ERROR", "Failed to create EESSITask for task '%s': '%s'", + task_path, str(err)) + continue + + log_message(LoggingScope.GROUP_OPS, "INFO", "Created EESSITask: '%s'", task) + + previous_state = None + current_state = task.determine_state() + log_message(LoggingScope.GROUP_OPS, "INFO", "Task '%s' is in state '%s'", + task_path, 
current_state.name) + while (current_state is not None and + current_state != EESSITaskState.DONE and + previous_state != current_state): + previous_state = current_state + log_message(LoggingScope.GROUP_OPS, "INFO", + "Task '%s': BEFORE handle(): previous state = '%s', current state = '%s'", + task_path, previous_state.name, current_state.name) + current_state = task.handle() + log_message(LoggingScope.GROUP_OPS, "INFO", + "Task '%s': AFTER handle(): previous state = '%s', current state = '%s'", + task_path, previous_state.name, current_state.name) + + except Exception as err: + log_message(LoggingScope.ERROR, "ERROR", "Failed to process task '%s': '%s'", task_path, str(err)) + continue + + +@log_function_entry_exit() +def find_deployment_tasks(s3_bucket: EESSIS3Bucket, extensions: List[str] = None) -> List[str]: + """ + Return a list of all task files in an S3 bucket with the given extensions, + but only if a corresponding payload file exists (same name without extension). + + Args: + s3_bucket: EESSIS3Bucket instance + extensions: List of file extensions to look for (default: ['.task']) + + Returns: + List of task filenames found in the bucket that have a corresponding payload + """ + if extensions is None: + extensions = [".task"] + + files = [] + continuation_token = None + + while True: + # list objects with pagination + if continuation_token: + response = s3_bucket.list_objects_v2( + ContinuationToken=continuation_token + ) + else: + response = s3_bucket.list_objects_v2() + + # add files from this page + files.extend([obj["Key"] for obj in response.get("Contents", [])]) + + # check if there are more pages + if response.get("IsTruncated"): + continuation_token = response.get("NextContinuationToken") + else: + break + + # create a set of all files for faster lookup + file_set = set(files) + + # return only task files that have a corresponding payload + result = [] + for file in files: + for ext in extensions: + if file.endswith(ext) and file[:-len(ext)] in file_set: + result.append(file) + break # found a matching extension, no need to check other extensions + + return result + + +if __name__ == "__main__": + try: + main() + except PidFileError as err: + error(f"Another instance of this script is already running! Error: '{err}'") diff --git a/scripts/automated_ingestion/pytest.sh b/scripts/automated_ingestion/pytest.sh new file mode 100755 index 00000000..f8b4e170 --- /dev/null +++ b/scripts/automated_ingestion/pytest.sh @@ -0,0 +1,10 @@ +#!/bin/bash +# +# This file is part of the EESSI filesystem layer, +# see https://github.com/EESSI/filesystem-layer +# +# author: Thomas Roeblitz (@trz42) +# +# license: GPLv2 +# +PYTHONPATH=$PWD:$PYTHONPATH pytest --capture=no "$@" \ No newline at end of file diff --git a/scripts/automated_ingestion/unit_tests/__init__.py b/scripts/automated_ingestion/unit_tests/__init__.py new file mode 100644 index 00000000..467d5dfe --- /dev/null +++ b/scripts/automated_ingestion/unit_tests/__init__.py @@ -0,0 +1 @@ +# This file makes the unit_tests directory a Python package diff --git a/scripts/automated_ingestion/unit_tests/test_basic.py b/scripts/automated_ingestion/unit_tests/test_basic.py new file mode 100644 index 00000000..7e382cbd --- /dev/null +++ b/scripts/automated_ingestion/unit_tests/test_basic.py @@ -0,0 +1,27 @@ +""" +Basic test file to prevent pytest from failing with exit code 5 when no tests are found. 
+ +This file is part of the EESSI filesystem layer, +see https://github.com/EESSI/filesystem-layer + +author: Thomas Roeblitz (@trz42) + +license: GPLv2 +""" + +import pytest + + +def test_basic_placeholder(): + """Basic placeholder test that always passes.""" + assert True + + +def test_import_modules(): + """Test that we can import the main modules without errors.""" + try: + import eessi_logging + # Verify the modules were imported successfully + assert eessi_logging is not None + except ImportError as err: + pytest.skip(f"Module import failed: {err}")
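

Two further test sketches that could live alongside the placeholder tests above (assumptions: they import ingest_bundles and eessi_task_action directly, and the fake bucket below only mimics the parts of the S3 client that find_deployment_tasks() actually uses):

def test_task_action_str():
    """str() of an EESSITaskAction should be its lower-cased name."""
    try:
        from eessi_task_action import EESSITaskAction
    except ImportError as err:
        pytest.skip(f"Module import failed: {err}")
    assert str(EESSITaskAction.ADD) == "add"
    assert str(EESSITaskAction.DELETE) == "delete"


class _FakeBucket:
    """Minimal stand-in for EESSIS3Bucket: only implements list_objects_v2(), without pagination."""

    def __init__(self, keys):
        self.keys = keys

    def list_objects_v2(self, ContinuationToken=None):
        return {"Contents": [{"Key": key} for key in self.keys], "IsTruncated": False}


def test_find_deployment_tasks_requires_payload():
    """A task file is only reported when the matching payload (same name without extension) exists."""
    try:
        from ingest_bundles import find_deployment_tasks
    except ImportError as err:
        pytest.skip(f"Module import failed: {err}")
    keys = ["repo/pr/a.tar.gz", "repo/pr/a.tar.gz.task", "repo/pr/b.tar.gz.task"]
    assert find_deployment_tasks(_FakeBucket(keys), [".task"]) == ["repo/pr/a.tar.gz.task"]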