diff --git a/.flake8 b/.flake8
new file mode 100644
index 00000000..852c375b
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,20 @@
+# This file is part of the EESSI filesystem layer,
+# see https://github.com/EESSI/filesystem-layer
+#
+# author: Thomas Roeblitz (@trz42)
+#
+# license: GPLv2
+#
+
+[flake8]
+exclude =
+    scripts/check-stratum-servers.py,
+    scripts/automated_ingestion/automated_ingestion.py,
+    scripts/automated_ingestion/eessitarball.py,
+    scripts/automated_ingestion/utils.py
+
+# ignore "Black would make changes" produced by flake8-black
+# see also https://github.com/houndci/hound/issues/1769
+extend-ignore = BLK100
+
+max-line-length = 120
diff --git a/.github/workflows/test-ingest-python-code.yml b/.github/workflows/test-ingest-python-code.yml
new file mode 100644
index 00000000..9e341783
--- /dev/null
+++ b/.github/workflows/test-ingest-python-code.yml
@@ -0,0 +1,49 @@
+# This file is part of the EESSI filesystem layer,
+# see https://github.com/EESSI/filesystem-layer
+#
+# author: Thomas Roeblitz (@trz42)
+#
+# license: GPLv2
+#
+
+name: Run tests
+on: [push, pull_request]
+# Declare default permissions as read only.
+permissions: read-all
+jobs:
+  test:
+    runs-on: ubuntu-24.04
+    strategy:
+      matrix:
+        # for now, only test with Python 3.9+ (since we're testing on Ubuntu 24.04)
+        #python: [3.6, 3.7, 3.8, 3.9, '3.10', '3.11']
+        python: ['3.9', '3.10', '3.11']
+      fail-fast: false
+    steps:
+      - name: checkout
+        uses: actions/checkout@93ea575cb5d8a053eaa0ac8fa3b40d7e05a33cc8  # v3.1.0
+
+      - name: set up Python
+        uses: actions/setup-python@13ae5bb136fac2878aff31522b9efb785519f984  # v4.3.0
+        with:
+          python-version: ${{matrix.python}}
+
+      - name: Install required Python packages, pytest, and flake8
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install -r scripts/automated_ingestion/requirements.txt
+          python -m pip install pytest
+          python -m pip install --upgrade flake8
+
+      - name: Run test suite (without coverage)
+        run: |
+          ./scripts/automated_ingestion/pytest.sh scripts/automated_ingestion --verbose
+
+      - name: Run test suite (with coverage)
+        run: |
+          python -m pip install pytest-cov
+          ./scripts/automated_ingestion/pytest.sh scripts/automated_ingestion -q --cov=scripts/automated_ingestion/eessi_logging.py
+
+      - name: Run flake8 to verify PEP8-compliance of Python code
+        run: |
+          flake8 scripts/automated_ingestion --exclude=scripts/automated_ingestion/automated_ingestion.py,scripts/automated_ingestion/eessitarball.py
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 39af2bac..893c00e4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,4 @@
 build
 hosts
+.coverage
+**/__pycache__
diff --git a/scripts/automated_ingestion/.coveragerc b/scripts/automated_ingestion/.coveragerc
new file mode 100644
index 00000000..2941a1ed
--- /dev/null
+++ b/scripts/automated_ingestion/.coveragerc
@@ -0,0 +1,6 @@
+[run]
+omit =
+    scripts/automated_ingestion/automated_ingestion.py
+    scripts/automated_ingestion/eessitarball.py
+    scripts/automated_ingestion/utils.py
+    scripts/automated_ingestion/unit_tests/*.py
diff --git a/scripts/automated_ingestion/automated_ingestion.py b/scripts/automated_ingestion/automated_ingestion.py
index 92dac552..7a7f9dc8 100755
--- a/scripts/automated_ingestion/automated_ingestion.py
+++ b/scripts/automated_ingestion/automated_ingestion.py
@@ -81,7 +81,7 @@ def parse_args():
     return args
 
 
-@pid.decorator.pidfile('automated_ingestion.pid')
+@pidfile('shared_lock.pid')  # noqa: F401
 def main():
     """Main function."""
     args = parse_args()
diff --git a/scripts/automated_ingestion/eessi_data_object.py b/scripts/automated_ingestion/eessi_data_object.py
new file mode 100644
index 00000000..4989f24c
--- /dev/null
+++ b/scripts/automated_ingestion/eessi_data_object.py
@@ -0,0 +1,344 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import configparser
+import subprocess
+
+from eessi_logging import log_function_entry_exit, log_message, LoggingScope
+from eessi_remote_storage_client import DownloadMode, EESSIRemoteStorageClient
+
+
+@dataclass
+class EESSIDataAndSignatureObject:
+    """Class representing an EESSI data file and its signature in remote storage and locally."""
+
+    # configuration
+    config: configparser.ConfigParser
+
+    # remote paths
+    remote_file_path: str  # path to data file in remote storage
+    remote_sig_path: str  # path to signature file in remote storage
+
+    # local paths
+    local_file_path: Path  # path to local data file
+    local_sig_path: Path  # path to local signature file
+
+    # remote storage client
+    remote_client: EESSIRemoteStorageClient
+
+    @log_function_entry_exit()
+    def __init__(
+        self,
+        config: configparser.ConfigParser,
+        remote_file_path: str,
+        remote_client: EESSIRemoteStorageClient,
+    ):
+        """
+        Initialize an EESSI data and signature object handler.
+
+        Args:
+            config: configuration object containing remote storage and local directory information
+            remote_file_path: path to data file in remote storage
+            remote_client: remote storage client implementing the EESSIRemoteStorageClient protocol
+        """
+        self.config = config
+        self.remote_file_path = remote_file_path
+        sig_ext = config["signatures"]["signature_file_extension"]
+        self.remote_sig_path = remote_file_path + sig_ext
+
+        # set up local paths
+        local_dir = Path(config["paths"]["download_dir"])
+        # use the full remote path structure, removing any leading slashes
+        remote_path = remote_file_path.lstrip("/")
+        self.local_file_path = local_dir.joinpath(remote_path)
+        self.local_sig_path = local_dir.joinpath(remote_path + sig_ext)
+        self.remote_client = remote_client
+
+        log_message(LoggingScope.DEBUG, "DEBUG", "Initialized EESSIDataAndSignatureObject for '%s'", remote_file_path)
+        log_message(LoggingScope.DEBUG, "DEBUG", "Local file path: '%s'", self.local_file_path)
+        log_message(LoggingScope.DEBUG, "DEBUG", "Local signature path: '%s'", self.local_sig_path)
+
+    @log_function_entry_exit()
+    def _get_etag_file_path(self, local_path: Path) -> Path:
+        """Get the path to the .etag file for a given local file."""
+        return local_path.with_suffix(".etag")
+
+    @log_function_entry_exit()
+    def _get_local_etag(self, local_path: Path) -> Optional[str]:
+        """Get the ETag of a local file from its .etag file."""
+        etag_path = self._get_etag_file_path(local_path)
+        if etag_path.exists():
+            try:
+                with open(etag_path, "r") as f:
+                    return f.read().strip()
+            except Exception as err:
+                log_message(LoggingScope.DEBUG, "WARNING", "Failed to read ETag file '%s': '%s'", etag_path, str(err))
+                return None
+        return None
+
+    @log_function_entry_exit()
+    def get_etags(self) -> tuple[Optional[str], Optional[str]]:
+        """
+        Get the ETags of both the data file and its signature.
+
+        Returns:
+            Tuple containing (data_file_etag, signature_file_etag)
+        """
+        return (
+            self._get_local_etag(self.local_file_path),
+            self._get_local_etag(self.local_sig_path)
+        )
+
+    @log_function_entry_exit()
+    def verify_signature(self) -> bool:
+        """
+        Verify the signature of the data file using the corresponding signature file.
+
+        Returns:
+            bool: True if the signature is valid or if signatures are not required, False otherwise
+        """
+        # check if signature file exists
+        if not self.local_sig_path.exists():
+            log_message(LoggingScope.VERIFICATION, "WARNING", "Signature file '%s' is missing",
+                        self.local_sig_path)
+
+            # if signatures are required, return failure
+            if self.config["signatures"].getboolean("signatures_required", True):
+                log_message(LoggingScope.ERROR, "ERROR", "Signature file '%s' is missing and signatures are required",
+                            self.local_sig_path)
+                return False
+            else:
+                log_message(LoggingScope.VERIFICATION, "INFO",
+                            "Signature file '%s' is missing, but signatures are not required",
+                            self.local_sig_path)
+                return True
+
+        # if signatures are provided, we should always verify them, regardless of the signatures_required setting
+        verify_runenv = self.config["signatures"]["signature_verification_runenv"].split()
+        verify_script = self.config["signatures"]["signature_verification_script"]
+        allowed_signers_file = self.config["signatures"]["allowed_signers_file"]
+
+        # check if verification tools exist
+        if not Path(verify_script).exists():
+            log_message(LoggingScope.ERROR, "ERROR",
+                        "Unable to verify signature: verification script '%s' does not exist", verify_script)
+            return False
+
+        if not Path(allowed_signers_file).exists():
+            log_message(LoggingScope.ERROR, "ERROR",
+                        "Unable to verify signature: allowed signers file '%s' does not exist", allowed_signers_file)
+            return False
+
+        # run the verification command with named parameters
+        cmd = verify_runenv + [
+            verify_script,
+            "--verify",
+            "--allowed-signers-file", allowed_signers_file,
+            "--file", str(self.local_file_path),
+            "--signature-file", str(self.local_sig_path)
+        ]
+        log_message(LoggingScope.VERIFICATION, "INFO", "Running command: '%s'", " ".join(cmd))
+
+        try:
+            result = subprocess.run(cmd, capture_output=True, text=True)
+            if result.returncode == 0:
+                log_message(LoggingScope.VERIFICATION, "INFO",
+                            "Successfully verified signature for '%s'", self.local_file_path)
+                log_message(LoggingScope.VERIFICATION, "DEBUG", "  stdout: '%s'", result.stdout)
+                log_message(LoggingScope.VERIFICATION, "DEBUG", "  stderr: '%s'", result.stderr)
+                return True
+            else:
+                log_message(LoggingScope.ERROR, "ERROR",
+                            "Signature verification failed for '%s'", self.local_file_path)
+                log_message(LoggingScope.ERROR, "ERROR", "  stdout: '%s'", result.stdout)
+                log_message(LoggingScope.ERROR, "ERROR", "  stderr: '%s'", result.stderr)
+                return False
+        except Exception as err:
+            log_message(LoggingScope.ERROR, "ERROR",
+                        "Error during signature verification for '%s': '%s'",
+                        self.local_file_path, str(err))
+            return False
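+
+    # NOTE (editorial summary, derived from the download() logic below): the
+    # three DownloadMode values behave as follows:
+    #   FORCE        -> always download
+    #   CHECK_REMOTE -> download when local files or stored ETags are missing,
+    #                   or when the remote ETags differ from the stored ones
+    #   CHECK_LOCAL  -> download only when a local file is missing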
+
+    @log_function_entry_exit()
+    def download(self, mode: DownloadMode = DownloadMode.CHECK_REMOTE) -> bool:
+        """
+        Download data file and signature based on the specified mode.
+
+        Args:
+            mode: Download mode to use
+
+        Returns:
+            True if files were downloaded, False otherwise
+        """
+        # if mode is FORCE, we always download regardless of local or remote state
+        if mode == DownloadMode.FORCE:
+            should_download = True
+            log_message(LoggingScope.DOWNLOAD, "INFO", "Forcing download of '%s'", self.remote_file_path)
+        # for CHECK_REMOTE mode, check if we can optimize
+        elif mode == DownloadMode.CHECK_REMOTE:
+            # optimization: check if local files exist first
+            local_files_exist = (
+                self.local_file_path.exists() and
+                self.local_sig_path.exists()
+            )
+
+            # if files don't exist locally, we can skip ETag checks
+            if not local_files_exist:
+                log_message(LoggingScope.DOWNLOAD, "INFO",
+                            "Local files missing, skipping ETag checks and downloading '%s'",
+                            self.remote_file_path)
+                should_download = True
+            else:
+                # first check if we have local ETags
+                try:
+                    local_file_etag = self._get_local_etag(self.local_file_path)
+                    local_sig_etag = self._get_local_etag(self.local_sig_path)
+
+                    if local_file_etag:
+                        log_message(LoggingScope.DOWNLOAD, "DEBUG", "Local file ETag: '%s'", local_file_etag)
+                    else:
+                        log_message(LoggingScope.DOWNLOAD, "DEBUG", "No local file ETag found")
+                    if local_sig_etag:
+                        log_message(LoggingScope.DOWNLOAD, "DEBUG", "Local signature ETag: '%s'", local_sig_etag)
+                    else:
+                        log_message(LoggingScope.DOWNLOAD, "DEBUG", "No local signature ETag found")
+
+                    # if we don't have local ETags, we need to download
+                    if not local_file_etag or not local_sig_etag:
+                        should_download = True
+                        log_message(LoggingScope.DOWNLOAD, "INFO", "Missing local ETags, downloading '%s'",
+                                    self.remote_file_path)
+                    else:
+                        # get remote ETags and compare
+                        remote_file_etag = self.remote_client.get_metadata(self.remote_file_path)["ETag"]
+                        remote_sig_etag = self.remote_client.get_metadata(self.remote_sig_path)["ETag"]
+                        log_message(LoggingScope.DOWNLOAD, "DEBUG", "Remote file ETag: '%s'", remote_file_etag)
+                        log_message(LoggingScope.DOWNLOAD, "DEBUG", "Remote signature ETag: '%s'", remote_sig_etag)
+
+                        should_download = (
+                            remote_file_etag != local_file_etag or
+                            remote_sig_etag != local_sig_etag
+                        )
+                        if should_download:
+                            if remote_file_etag != local_file_etag:
+                                log_message(LoggingScope.DOWNLOAD, "INFO", "File ETag changed from '%s' to '%s'",
+                                            local_file_etag, remote_file_etag)
+                            if remote_sig_etag != local_sig_etag:
+                                log_message(LoggingScope.DOWNLOAD, "INFO", "Signature ETag changed from '%s' to '%s'",
+                                            local_sig_etag, remote_sig_etag)
+                            log_message(LoggingScope.DOWNLOAD, "INFO", "Remote files have changed, downloading '%s'",
+                                        self.remote_file_path)
+                        else:
+                            log_message(LoggingScope.DOWNLOAD, "INFO",
+                                        "Remote files unchanged, skipping download of '%s'",
+                                        self.remote_file_path)
+                except Exception as etag_err:
+                    # if we get any error with ETags, we'll just download the files
+                    log_message(LoggingScope.DOWNLOAD, "DEBUG", "Error handling ETags, will download files: '%s'",
+                                str(etag_err))
+                    should_download = True
+        else:  # check_local
+            should_download = (
+                not self.local_file_path.exists() or
+                not self.local_sig_path.exists()
+            )
+            if should_download:
+                if not self.local_file_path.exists():
+                    log_message(LoggingScope.DOWNLOAD, "INFO", "Local file missing: '%s'", self.local_file_path)
+                if not self.local_sig_path.exists():
+                    log_message(LoggingScope.DOWNLOAD, "INFO", "Local signature missing: '%s'", self.local_sig_path)
+                log_message(LoggingScope.DOWNLOAD, "INFO", "Local files missing, downloading '%s'",
+                            self.remote_file_path)
+            else:
+                log_message(LoggingScope.DOWNLOAD, "INFO",
+                            "Local files exist, skipping download of '%s'",
+                            self.remote_file_path)
+
+        if not should_download:
+            return False
+
+        # ensure local directory exists
+        self.local_file_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # download files
+        try:
+            # download the main file first
+            self.remote_client.download(self.remote_file_path, str(self.local_file_path))
+
+            # get and log the ETag of the downloaded file
+            try:
+                file_etag = self._get_local_etag(self.local_file_path)
+                log_message(LoggingScope.DOWNLOAD, "DEBUG", "Downloaded '%s' with ETag: '%s'",
+                            self.remote_file_path, file_etag)
+            except Exception as etag_err:
+                log_message(LoggingScope.DOWNLOAD, "DEBUG", "Error getting ETag for '%s': '%s'",
+                            self.remote_file_path, str(etag_err))
+
+            # try to download the signature file
+            try:
+                self.remote_client.download(self.remote_sig_path, str(self.local_sig_path))
+                try:
+                    sig_etag = self._get_local_etag(self.local_sig_path)
+                    log_message(LoggingScope.DOWNLOAD, "DEBUG", "Downloaded '%s' with ETag: '%s'",
+                                self.remote_sig_path, sig_etag)
+                except Exception as etag_err:
+                    log_message(LoggingScope.DOWNLOAD, "DEBUG", "Error getting ETag for '%s': '%s'",
+                                self.remote_sig_path, str(etag_err))
+                log_message(LoggingScope.DOWNLOAD, "INFO", "Successfully downloaded '%s' and its signature",
+                            self.remote_file_path)
+            except Exception as sig_err:
+                # check if signatures are required
+                if self.config["signatures"].getboolean("signatures_required", True):
+                    # if signatures are required, clean up everything since we can't proceed
+                    if self.local_file_path.exists():
+                        self.local_file_path.unlink()
+                    # clean up etag files regardless of whether their data files exist
+                    file_etag_path = self._get_etag_file_path(self.local_file_path)
+                    if file_etag_path.exists():
+                        file_etag_path.unlink()
+                    sig_etag_path = self._get_etag_file_path(self.local_sig_path)
+                    if sig_etag_path.exists():
+                        sig_etag_path.unlink()
+                    log_message(LoggingScope.ERROR, "ERROR", "Failed to download required signature for '%s': '%s'",
+                                self.remote_file_path, str(sig_err))
+                    raise
+                else:
+                    # if signatures are optional, just clean up any partial signature files
+                    if self.local_sig_path.exists():
+                        self.local_sig_path.unlink()
+                    sig_etag_path = self._get_etag_file_path(self.local_sig_path)
+                    if sig_etag_path.exists():
+                        sig_etag_path.unlink()
+                    log_message(LoggingScope.DOWNLOAD, "WARNING",
+                                "Failed to download optional signature for '%s': '%s'",
+                                self.remote_file_path, str(sig_err))
+                    log_message(LoggingScope.DOWNLOAD, "INFO",
+                                "Successfully downloaded '%s' (signature optional)",
+                                self.remote_file_path)
+
+            return True
+        except Exception as err:
+            # this catch block is only for errors in the main file download
+            # clean up partially downloaded files and their etags
+            if self.local_file_path.exists():
+                self.local_file_path.unlink()
+            if self.local_sig_path.exists():
+                self.local_sig_path.unlink()
+            # clean up etag files regardless of whether their data files exist
+            file_etag_path = self._get_etag_file_path(self.local_file_path)
+            if file_etag_path.exists():
+                file_etag_path.unlink()
+            sig_etag_path = self._get_etag_file_path(self.local_sig_path)
+            if sig_etag_path.exists():
+                sig_etag_path.unlink()
+            log_message(LoggingScope.ERROR, "ERROR", "Failed to download '%s': '%s'", self.remote_file_path, str(err))
+            raise
+
+    @log_function_entry_exit()
+    def get_url(self) -> str:
+        """Get the URL of the data file."""
+        return f"https://{self.remote_client.bucket}.s3.amazonaws.com/{self.remote_file_path}"
+
+    def __str__(self) -> str:
+        """Return a string representation of the EESSI data and signature object."""
+        return f"EESSIDataAndSignatureObject({self.remote_file_path})"
diff --git a/scripts/automated_ingestion/eessi_logging.py b/scripts/automated_ingestion/eessi_logging.py
new file mode 100644
index 00000000..f94c8495
--- /dev/null
+++ b/scripts/automated_ingestion/eessi_logging.py
@@ -0,0 +1,262 @@
+import functools
+import inspect
+import logging
+import os
+import sys
+import time
+
+from enum import IntFlag, auto
+from typing import Callable, Union
+
+
+LOG_LEVELS = {
+    'DEBUG': logging.DEBUG,
+    'INFO': logging.INFO,
+    'WARNING': logging.WARNING,
+    'ERROR': logging.ERROR,
+    'CRITICAL': logging.CRITICAL
+}
+
+
+class LoggingScope(IntFlag):
+    """Enumeration of different logging scopes."""
+    NONE = 0
+    FUNC_ENTRY_EXIT = auto()   # Function entry/exit logging
+    DOWNLOAD = auto()          # Logging related to file downloads
+    VERIFICATION = auto()      # Logging related to signature and checksum verification
+    STATE_OPS = auto()         # Logging related to tarball state operations
+    GITHUB_OPS = auto()        # Logging related to GitHub operations (PRs, issues, etc.)
+    GROUP_OPS = auto()         # Logging related to tarball group operations
+    TASK_OPS = auto()          # Logging related to task operations
+    TASK_OPS_DETAILS = auto()  # Logging related to task operations (detailed)
+    ERROR = auto()             # Error logging (separate from other scopes for easier filtering)
+    DEBUG = auto()             # Debug-level logging (separate from other scopes for easier filtering)
+    ALL = (FUNC_ENTRY_EXIT | DOWNLOAD | VERIFICATION | STATE_OPS |
+           GITHUB_OPS | GROUP_OPS | TASK_OPS | TASK_OPS_DETAILS | ERROR | DEBUG)
+
+
+# Global setting for logging scopes
+ENABLED_LOGGING_SCOPES = LoggingScope.NONE
+
+
+# Global variable to track call stack depth
+_call_stack_depth = 0
+
+
+def error(msg, code=1):
+    """Log an error message and exit."""
+    log_message(LoggingScope.ERROR, 'ERROR', msg)
+    sys.exit(code)
+
+
+def is_logging_scope_enabled(scope: LoggingScope) -> bool:
+    """Check if a specific logging scope is enabled."""
+    return bool(ENABLED_LOGGING_SCOPES & scope)
+
+
+def log_function_entry_exit(logger: logging.Logger = None) -> Callable:
+    """
+    Decorator that logs function entry and exit with timing information.
+    Only logs if the FUNC_ENTRY_EXIT scope is enabled.
+
+    Args:
+        logger: Optional logger instance. If not provided, uses the module's logger.
+    """
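+    # NOTE (illustrative usage, not part of the original change; process_task is
+    # a hypothetical function name): entry/exit logging only appears once the
+    # FUNC_ENTRY_EXIT scope has been enabled, e.g.:
+    #
+    #     set_logging_scopes("+FUNC_ENTRY_EXIT")
+    #
+    #     @log_function_entry_exit()
+    #     def process_task(task):
+    #         ...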
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            global _call_stack_depth
+
+            if not is_logging_scope_enabled(LoggingScope.FUNC_ENTRY_EXIT):
+                return func(*args, **kwargs)
+
+            if logger is None:
+                log = logging.getLogger(func.__module__)
+            else:
+                log = logger
+
+            # Get context information if available
+            context = ""
+            if len(args) > 0 and hasattr(args[0], 'object'):
+                # For EessiTarball methods, show the tarball name and state
+                tarball = args[0]
+                filename = os.path.basename(tarball.object)
+
+                # Format filename to show important parts
+                if len(filename) > 30:
+                    parts = filename.split('-')
+                    if len(parts) >= 6:  # Ensure we have all required parts
+                        # Get version, component, last part of architecture, and epoch
+                        version = parts[1]
+                        component = parts[2]
+                        arch_last = parts[-2].split('-')[-1]  # Last part of architecture
+                        epoch = parts[-1]  # includes file extension
+                        filename = f"{version}-{component}-{arch_last}-{epoch}"
+                    else:
+                        # Fallback to simple truncation if format doesn't match
+                        filename = f"{filename[:15]}...{filename[-12:]}"
+
+                context = f" [{filename}"
+                if hasattr(tarball, 'state'):
+                    context += f" in {tarball.state}"
+                context += "]"
+
+            # Create indentation based on call stack depth
+            indent = " " * _call_stack_depth
+
+            # Get file name and line number where the function is defined
+            file_name = os.path.basename(inspect.getsourcefile(func))
+            source_lines, start_line = inspect.getsourcelines(func)
+            # Find the line with the actual function definition
+            def_line = next(i for i, line in enumerate(source_lines) if line.strip().startswith('def '))
+            def_line_no = start_line + def_line
+            # Find the last non-empty line of the function
+            last_line = next(i for i, line in enumerate(reversed(source_lines)) if line.strip())
+            last_line_no = start_line + len(source_lines) - 1 - last_line
+
+            start_time = time.time()
+            log.info(f"{indent}[FUNC_ENTRY_EXIT] Entering {func.__name__} at {file_name}:{def_line_no}{context}")
+            _call_stack_depth += 1
+            try:
+                result = func(*args, **kwargs)
+                _call_stack_depth -= 1
+                end_time = time.time()
+                # For normal returns, show the last line of the function
+                log.info(f"{indent}[FUNC_ENTRY_EXIT] Leaving {func.__name__} at {file_name}:{last_line_no}"
+                         f"{context} (took {end_time - start_time:.2f}s)")
+                return result
+            except Exception as err:
+                _call_stack_depth -= 1
+                end_time = time.time()
+                # For exceptions, try to get the line number from the exception
+                try:
+                    exc_line_no = err.__traceback__.tb_lineno
+                except AttributeError:
+                    exc_line_no = last_line_no
+                log.info(f"{indent}[FUNC_ENTRY_EXIT] Leaving {func.__name__} at {file_name}:{exc_line_no}"
+                         f"{context} with exception (took {end_time - start_time:.2f}s)")
+                raise err
+        return wrapper
+    return decorator
+
+
+def log_message(scope, level, msg, *args, logger=None, **kwargs):
+    """
+    Log a message if either:
+    1. The specified scope is enabled, OR
+    2. The current log level is equal to or higher than the specified level
+
+    Args:
+        scope: LoggingScope value indicating which scope this logging belongs to
+        level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
+        msg: Message to log
+        *args, **kwargs: Additional arguments to pass to the logging function
+        logger: Optional logger instance. If not provided, uses the root logger.
+    """
+    log = logger or logging.getLogger()
+    log_level = getattr(logging, level.upper())
+
+    # Check if either condition is met
+    if not (is_logging_scope_enabled(scope) or log_level >= log.getEffectiveLevel()):
+        return
+
+    # Create indentation based on call stack depth
+    indent = " " * _call_stack_depth
+    # Add scope to the message
+    scoped_msg = f"[{scope.name}] {msg}"
+    indented_msg = f"{indent}{scoped_msg}"
+
+    # If scope is enabled, use the temporary handler
+    if is_logging_scope_enabled(scope):
+        # Save original handlers
+        original_handlers = list(log.handlers)
+
+        # Create a temporary handler that accepts all levels
+        temp_handler = logging.StreamHandler(sys.stdout)
+        temp_handler.setLevel(logging.DEBUG)
+        temp_handler.setFormatter(logging.Formatter('%(levelname)-8s: %(message)s'))
+
+        try:
+            # Remove existing handlers temporarily
+            for handler in original_handlers:
+                log.removeHandler(handler)
+
+            # Add temporary handler
+            log.addHandler(temp_handler)
+
+            # Log the message
+            log_func = getattr(log, level.lower())
+            log_func(indented_msg, *args, **kwargs)
+        finally:
+            log.removeHandler(temp_handler)
+            # Restore original handlers
+            for handler in original_handlers:
+                if handler not in log.handlers:
+                    log.addHandler(handler)
+    # Only use normal logging if scope is not enabled AND level is high enough
+    elif not is_logging_scope_enabled(scope) and log_level >= log.getEffectiveLevel():
+        # Use normal logging with level check
+        log_func = getattr(log, level.lower())
+        log_func(indented_msg, *args, **kwargs)
+
+
+def set_logging_scopes(scopes: Union[LoggingScope, str, list[str]]) -> None:
+    """
+    Set the enabled logging scopes.
+
+    Args:
+        scopes: Can be:
+            - A LoggingScope value
+            - A string with comma-separated values using +/- syntax:
+                - "+SCOPE" to enable a scope
+                - "-SCOPE" to disable a scope
+                - "ALL" or "+ALL" to enable all scopes
+                - "-ALL" to disable all scopes
+              Examples:
+                "+FUNC_ENTRY_EXIT"  # Enable only function entry/exit
+                "+FUNC_ENTRY_EXIT,-EXAMPLE_SCOPE"  # Enable function entry/exit but disable example
+                "+ALL,-FUNC_ENTRY_EXIT"  # Enable all scopes except function entry/exit
+            - A list of such strings
+    """
+    global ENABLED_LOGGING_SCOPES
+
+    if isinstance(scopes, LoggingScope):
+        ENABLED_LOGGING_SCOPES = scopes
+        return
+
+    if isinstance(scopes, str):
+        # Start with no scopes enabled
+        ENABLED_LOGGING_SCOPES = LoggingScope.NONE
+
+        # Split into individual scope specifications
+        scope_specs = [s.strip() for s in scopes.split(",")]
+
+        for spec in scope_specs:
+            if not spec:
+                continue
+
+            # Check for ALL special case
+            if spec.upper() in ["ALL", "+ALL"]:
+                ENABLED_LOGGING_SCOPES = LoggingScope.ALL
+                continue
+            elif spec.upper() == "-ALL":
+                ENABLED_LOGGING_SCOPES = LoggingScope.NONE
+                continue
+
+            # Parse scope name and operation
+            operation = spec[0]
+            scope_name = spec[1:].strip().upper()
+
+            try:
+                scope_enum = LoggingScope[scope_name]
+                if operation == '+':
+                    ENABLED_LOGGING_SCOPES |= scope_enum
+                elif operation == '-':
+                    ENABLED_LOGGING_SCOPES &= ~scope_enum
+                else:
+                    logging.warning(f"Invalid operation '{operation}' in scope specification: {spec}")
+            except KeyError:
+                logging.warning(f"Unknown logging scope: {scope_name}")
+
+    elif isinstance(scopes, list):
+        # Convert list to comma-separated string and process
+        set_logging_scopes(",".join(scopes))
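+
+
+# NOTE (illustrative usage, not part of the original change): with the syntax
+# accepted by set_logging_scopes() above,
+#     set_logging_scopes("+ALL,-FUNC_ENTRY_EXIT")
+# enables every scope except function entry/exit logging, and
+#     set_logging_scopes(["+DOWNLOAD", "+VERIFICATION"])
+# enables exactly the download and verification scopes.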
diff --git a/scripts/automated_ingestion/eessi_remote_storage_client.py b/scripts/automated_ingestion/eessi_remote_storage_client.py
new file mode 100644
index 00000000..9f83d721
--- /dev/null
+++ b/scripts/automated_ingestion/eessi_remote_storage_client.py
@@ -0,0 +1,34 @@
+from enum import Enum
+from typing import Protocol, runtime_checkable
+
+
+class DownloadMode(Enum):
+    """Enum defining different modes for downloading files."""
+    FORCE = 'force'  # Always download and overwrite
+    CHECK_REMOTE = 'check-remote'  # Download if remote files have changed
+    CHECK_LOCAL = 'check-local'  # Download if files don't exist locally (default)
+
+
+@runtime_checkable
+class EESSIRemoteStorageClient(Protocol):
+    """Protocol defining the interface for remote storage clients."""
+
+    def get_metadata(self, remote_path: str) -> dict:
+        """Get metadata about a remote object.
+
+        Args:
+            remote_path: Path to the object in remote storage
+
+        Returns:
+            Dictionary containing object metadata, including 'ETag' key
+        """
+        ...
+
+    def download(self, remote_path: str, local_path: str) -> None:
+        """Download a remote file to a local location.
+
+        Args:
+            remote_path: Path to the object in remote storage
+            local_path: Local path where the file should be saved
+        """
+        ...
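+
+
+# NOTE (illustrative sketch, not part of the original change): since the
+# protocol is @runtime_checkable, conformance of a client such as EESSIS3Bucket
+# can be checked structurally at runtime, e.g.:
+#
+#     client = EESSIS3Bucket(config, "some-bucket")  # bucket name hypothetical
+#     assert isinstance(client, EESSIRemoteStorageClient)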
diff --git a/scripts/automated_ingestion/eessi_s3_bucket.py b/scripts/automated_ingestion/eessi_s3_bucket.py
new file mode 100644
index 00000000..bc5a8822
--- /dev/null
+++ b/scripts/automated_ingestion/eessi_s3_bucket.py
@@ -0,0 +1,191 @@
+import os
+from pathlib import Path
+from typing import Dict, Optional
+
+import boto3
+from botocore.exceptions import ClientError
+from eessi_logging import log_function_entry_exit, log_message, LoggingScope
+from eessi_remote_storage_client import EESSIRemoteStorageClient
+
+
+class EESSIS3Bucket(EESSIRemoteStorageClient):
+    """EESSI-specific S3 bucket implementation of the EESSIRemoteStorageClient protocol."""
+
+    @log_function_entry_exit()
+    def __init__(self, config, bucket_name: str):
+        """
+        Initialize the EESSI S3 bucket.
+
+        Args:
+            config: Configuration object containing:
+                - secrets.aws_access_key_id: AWS access key ID (optional, can use AWS_ACCESS_KEY_ID env var)
+                - secrets.aws_secret_access_key: AWS secret access key (optional, can use
+                  AWS_SECRET_ACCESS_KEY env var)
+                - aws.endpoint_url: Custom endpoint URL for S3-compatible backends (optional)
+                - aws.verify: SSL verification setting (optional)
+                    - True: Verify SSL certificates (default)
+                    - False: Skip SSL certificate verification
+                    - str: Path to CA bundle file
+            bucket_name: Name of the S3 bucket to use
+        """
+        self.bucket = bucket_name
+
+        # get AWS credentials from environment or config
+        aws_access_key_id = os.getenv("AWS_ACCESS_KEY_ID") or config.get("secrets", "aws_access_key_id")
+        aws_secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY") or config.get("secrets", "aws_secret_access_key")
+
+        # configure boto3 client
+        client_config = {}
+
+        # add endpoint URL if specified in config
+        if config.has_option("aws", "endpoint_url"):
+            client_config["endpoint_url"] = config["aws"]["endpoint_url"]
+            log_message(LoggingScope.DEBUG, "DEBUG", "Using custom endpoint URL: '%s'", client_config["endpoint_url"])
+
+        # add SSL verification if specified in config
+        if config.has_option("aws", "verify"):
+            verify = config["aws"]["verify"]
+            if verify.lower() == "false":
+                client_config["verify"] = False
+                log_message(LoggingScope.DEBUG, "WARNING", "SSL verification disabled")
+            elif verify.lower() == "true":
+                client_config["verify"] = True
+            else:
+                client_config["verify"] = verify  # assume it's a path to CA bundle
+                log_message(LoggingScope.DEBUG, "DEBUG", "Using custom CA bundle: '%s'", verify)
+
+        self.client = boto3.client(
+            "s3",
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            **client_config
+        )
+        log_message(LoggingScope.DEBUG, "INFO", "Initialized S3 client for bucket: '%s'", self.bucket)
+
+    @log_function_entry_exit()
+    def download(self, remote_path: str, local_path: str) -> None:
+        """
+        Download an S3 object to a local location and store its ETag.
+
+        Args:
+            remote_path: Path to the object in S3
+            local_path: Local path where the file should be saved
+        """
+        try:
+            log_message(LoggingScope.DOWNLOAD, "INFO", "Downloading '%s' to '%s'", remote_path, local_path)
+            self.client.download_file(Bucket=self.bucket, Key=remote_path, Filename=local_path)
+            log_message(LoggingScope.DOWNLOAD, "INFO", "Successfully downloaded '%s' to '%s'", remote_path, local_path)
+        except ClientError as err:
+            log_message(LoggingScope.ERROR, "ERROR", "Failed to download '%s': '%s'", remote_path, str(err))
+            raise
+
+        # get metadata to obtain the ETag
+        metadata = self.get_metadata(remote_path)
+        etag = metadata["ETag"]
+
+        # store the ETag
+        self._write_etag(local_path, etag)
+
+    @log_function_entry_exit()
+    def download_file(self, key: str, filename: str) -> None:
+        """
+        Download a file from S3 to a local file.
+
+        Args:
+            key: The S3 key of the file to download
+            filename: The local path where the file should be saved
+        """
+        self.client.download_file(self.bucket, key, filename)
+
+    @log_function_entry_exit()
+    def get_bucket_url(self) -> Optional[str]:
+        """
+        Get the HTTPS URL for a bucket from an initialized boto3 client.
+        Works with both AWS S3 and MinIO/S3-compatible services.
+        """
+        try:
+            # check if this is a custom endpoint (MinIO) or AWS S3
+            endpoint_url = self.client.meta.endpoint_url
+
+            if endpoint_url:
+                # custom endpoint (MinIO, DigitalOcean Spaces, etc.)
+                # most S3-compatible services use path-style URLs
+                bucket_url = f"{endpoint_url}/{self.bucket}"
+            else:
+                # AWS S3 (no custom endpoint specified)
+                region = self.client.meta.region_name or 'us-east-1'
+
+                # AWS S3 virtual-hosted-style URLs
+                if region == "us-east-1":
+                    bucket_url = f"https://{self.bucket}.s3.amazonaws.com"
+                else:
+                    bucket_url = f"https://{self.bucket}.s3.{region}.amazonaws.com"
+
+            return bucket_url
+
+        except Exception as err:
+            log_message(LoggingScope.ERROR, "ERROR", "Error getting bucket URL: '%s'", str(err))
+            return None
+
+    @log_function_entry_exit()
+    def get_metadata(self, remote_path: str) -> Dict:
+        """
+        Get metadata about an S3 object.
+
+        Args:
+            remote_path: Path to the object in S3
+
+        Returns:
+            Dictionary containing object metadata, including 'ETag' key
+        """
+        try:
+            log_message(LoggingScope.DEBUG, "DEBUG", "Getting metadata for S3 object: '%s'", remote_path)
+            response = self.client.head_object(Bucket=self.bucket, Key=remote_path)
+            log_message(LoggingScope.DEBUG, "DEBUG", "Retrieved metadata for '%s': '%s'", remote_path, response)
+            return response
+        except ClientError as err:
+            log_message(LoggingScope.ERROR, "ERROR", "Failed to get metadata for '%s': '%s'", remote_path, str(err))
+            raise
+
+    @log_function_entry_exit()
+    def _get_etag_file_path(self, local_path: str) -> Path:
+        """Get the path to the .etag file for a given local file."""
+        return Path(local_path).with_suffix(".etag")
+
+    @log_function_entry_exit()
+    def list_objects_v2(self, **kwargs):
+        """
+        List objects in the bucket using the underlying boto3 client.
+
+        Args:
+            **kwargs: Additional arguments to pass to boto3.client.list_objects_v2
+
+        Returns:
+            Response from boto3.client.list_objects_v2
+        """
+        return self.client.list_objects_v2(Bucket=self.bucket, **kwargs)
+
+    @log_function_entry_exit()
+    def _read_etag(self, local_path: str) -> Optional[str]:
+        """Read the ETag from the .etag file if it exists."""
+        etag_path = self._get_etag_file_path(local_path)
+        if etag_path.exists():
+            try:
+                with open(etag_path, "r") as f:
+                    return f.read().strip()
+            except Exception as err:
+                log_message(LoggingScope.DEBUG, "WARNING", "Failed to read ETag file '%s': '%s'", etag_path, str(err))
+                return None
+        return None
+
+    @log_function_entry_exit()
+    def _write_etag(self, local_path: str, etag: str) -> None:
+        """Write the ETag to the .etag file."""
+        etag_path = self._get_etag_file_path(local_path)
+        try:
+            with open(etag_path, "w") as f:
+                f.write(etag)
+            log_message(LoggingScope.DEBUG, "DEBUG", "Wrote ETag to '%s'", etag_path)
+        except Exception as err:
+            log_message(LoggingScope.ERROR, "ERROR", "Failed to write ETag file '%s': '%s'", etag_path, str(err))
+            # if we can't write the etag file, it's not critical
+            # the file will just be downloaded again next time
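+
+
+# NOTE (editorial, derived from _get_etag_file_path()/_write_etag() above): the
+# ETag is stored in a sidecar file next to the download; since
+# Path.with_suffix() replaces only the final suffix, downloading to e.g.
+# "downloads/foo.tar.gz" writes the ETag to "downloads/foo.tar.etag", which
+# _read_etag() later compares against the remote ETag.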
diff --git a/scripts/automated_ingestion/eessi_task.py b/scripts/automated_ingestion/eessi_task.py
new file mode 100644
index 00000000..86fbd8df
--- /dev/null
+++ b/scripts/automated_ingestion/eessi_task.py
@@ -0,0 +1,1201 @@
+from enum import Enum, auto
+from functools import total_ordering
+from typing import Dict, List, Optional, Any
+
+import base64
+import os
+import subprocess
+import traceback
+
+from eessi_data_object import EESSIDataAndSignatureObject
+from eessi_logging import log_function_entry_exit, log_message, LoggingScope
+from eessi_task_action import EESSITaskAction
+from eessi_task_description import EESSITaskDescription
+from eessi_task_payload import EESSITaskPayload
+from utils import send_slack_message
+
+from github import Github, GithubException, InputGitTreeElement, UnknownObjectException
+from github.Branch import Branch
+from github.PullRequest import PullRequest
+
+
+@total_ordering
+class EESSITaskState(Enum):
+    UNDETERMINED = auto()    # The task state was not determined yet
+    NEW_TASK = auto()        # The task has been created but not yet processed
+    PAYLOAD_STAGED = auto()  # The task's payload has been staged to the Stratum-0
+    PULL_REQUEST = auto()    # A PR for the task has been created or updated in some staging repository
+    APPROVED = auto()        # The PR for the task has been approved
+    REJECTED = auto()        # The PR for the task has been rejected
+    INGESTED = auto()        # The task's payload has been applied to the target CernVM-FS repository
+    DONE = auto()            # The task has been completed
+
+    @classmethod
+    def from_string(
+        cls, name: str, default: Optional["EESSITaskState"] = None, case_sensitive: bool = False
+    ) -> "EESSITaskState":
+        log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "from_string: '%s'", name)
+        if case_sensitive:
+            to_return = cls.__members__.get(name, default)
+            log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "from_string will return: '%s'", to_return)
+            return to_return
+
+        try:
+            to_return = cls[name.upper()]
+            log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "from_string will return: '%s'", to_return)
+            return to_return
+        except KeyError:
+            return default
+
+    def __lt__(self, other):
+        if self.__class__ is other.__class__:
+            return self.value < other.value
+        return NotImplemented
+
+    def __str__(self):
+        return self.name.upper()
+
+
+class EESSITask:
+    description: EESSITaskDescription
+    payload: EESSITaskPayload
+    action: EESSITaskAction
+    git_repo: Github
+    config: Dict
+
+    @log_function_entry_exit()
+    def __init__(self, description: EESSITaskDescription, config: Dict, cvmfs_repo: str, git_repo: Github):
+        self.description = description
+        self.config = config
+        self.cvmfs_repo = cvmfs_repo
+        self.git_repo = git_repo
+        self.action = self._determine_task_action()
+
+        # define valid state transitions for all actions
+        # NOTE, for EESSITaskState.PULL_REQUEST, EESSITaskState.APPROVED must be the first element or
+        # _next_state() will not work correctly
+        self.valid_transitions = {
+            EESSITaskState.UNDETERMINED: [
+                EESSITaskState.NEW_TASK,
+                EESSITaskState.PAYLOAD_STAGED,
+                EESSITaskState.PULL_REQUEST,
+                EESSITaskState.APPROVED,
+                EESSITaskState.REJECTED,
+                EESSITaskState.INGESTED,
+                EESSITaskState.DONE,
+            ],
+            EESSITaskState.NEW_TASK: [EESSITaskState.PAYLOAD_STAGED],
+            EESSITaskState.PAYLOAD_STAGED: [EESSITaskState.PULL_REQUEST],
+            EESSITaskState.PULL_REQUEST: [EESSITaskState.APPROVED, EESSITaskState.REJECTED],
+            EESSITaskState.APPROVED: [EESSITaskState.INGESTED],
+            EESSITaskState.REJECTED: [],  # terminal state
+            EESSITaskState.INGESTED: [],  # terminal state
+            EESSITaskState.DONE: []  # virtual terminal state, not used to write on GitHub
+        }
+
+        self.payload = None
+        state = self.determine_state()
+        if state >= EESSITaskState.PAYLOAD_STAGED:
+            log_message(LoggingScope.TASK_OPS, "INFO", "initializing payload object in constructor for EESSITask")
+            self._init_payload_object()
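+
+    # NOTE (editorial summary, derived from valid_transitions above): the happy
+    # path of an ADD task is
+    #     NEW_TASK -> PAYLOAD_STAGED -> PULL_REQUEST -> APPROVED -> INGESTED
+    # with REJECTED as the alternative terminal outcome of a PULL_REQUEST.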
+
+    @log_function_entry_exit()
+    def _determine_task_action(self) -> EESSITaskAction:
+        """
+        Determine the action type based on task description metadata.
+        """
+        if "task" in self.description.metadata and "action" in self.description.metadata["task"]:
+            action_str = self.description.metadata["task"]["action"].lower()
+            if action_str == "nop":
+                return EESSITaskAction.NOP
+            elif action_str == "delete":
+                return EESSITaskAction.DELETE
+            elif action_str == "add":
+                return EESSITaskAction.ADD
+            elif action_str == "update":
+                return EESSITaskAction.UPDATE
+        # temporarily return EESSITaskAction.ADD as default because the metadata
+        # file does not have an action defined yet
+        return EESSITaskAction.ADD
+
+    @log_function_entry_exit()
+    def _state_file_with_prefix_exists_in_repo_branch(self, file_path_prefix: str, branch_name: str = None) -> bool:
+        """
+        Check if a file exists in a repository branch.
+
+        Args:
+            file_path_prefix: the prefix of the file path
+            branch_name: the branch to check
+
+        Returns:
+            True if a file with the prefix exists in the branch, False otherwise
+        """
+        branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+        # branch = self._get_branch_from_name(branch_name)
+        try:
+            # get all files in directory part of file_path_prefix
+            directory_part = os.path.dirname(file_path_prefix)
+            files = self.git_repo.get_contents(directory_part, ref=branch_name)
+            log_msg = "Found files '%s' in directory '%s' in branch '%s'"
+            log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", log_msg, files, directory_part, branch_name)
+            # check if any of the files has file_path_prefix as prefix
+            for file in files:
+                if file.path.startswith(file_path_prefix):
+                    log_msg = "Found file '%s' in directory '%s' in branch '%s'"
+                    log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", log_msg, file.path, directory_part, branch_name)
+                    return True
+            log_msg = "No file with prefix '%s' found in directory '%s' in branch '%s'"
+            log_message(LoggingScope.TASK_OPS, "INFO", log_msg, file_path_prefix, directory_part, branch_name)
+            return False
+        except UnknownObjectException:
+            # file_path does not exist in branch
+            log_msg = "Directory '%s' or file with prefix '%s' does not exist in branch '%s'"
+            log_message(LoggingScope.TASK_OPS, "INFO", log_msg, directory_part, file_path_prefix, branch_name)
+            return False
+        except GithubException as err:
+            if err.status == 404:
+                # file_path does not exist in branch
+                log_msg = "Directory '%s' or file with prefix '%s' does not exist in branch '%s'"
+                log_message(LoggingScope.TASK_OPS, "INFO", log_msg, directory_part, file_path_prefix, branch_name)
+                return False
+            else:
+                # if there was some other (e.g. connection) issue, log message and return False
+                log_msg = "Unable to determine the state of '%s', the GitHub API returned status '%s'!"
+                log_message(LoggingScope.ERROR, "WARNING", log_msg, file_path_prefix, err.status)
+                return False
+        return False
+
+    @log_function_entry_exit()
+    def _determine_sequence_numbers_including_task_file(self, repo: str, pr: str) -> Dict[int, bool]:
+        """
+        Determines in which sequence numbers the metadata/task file is included and in which it is not.
+        NOTE, we only need to check the default branch of the repository, because for a new task a file
+        is added to the default branch and for the subsequent processing of the task we use a different branch.
+        Thus, until the PR is closed, the task file stays in the default branch.
+
+        Args:
+            repo: the repository name
+            pr: the pull request number
+
+        Returns:
+            A dictionary with the sequence numbers as keys and a boolean value indicating if the metadata/task
+            file is included in that sequence number.
+
+        Idea:
+        - The deployment for a single source PR could be split into multiple staging PRs, each of which is
+          assigned a unique sequence number.
+        - For a given source PR (identified by the repo name and the PR number), a staging PR using a branch
+          named `REPO/PR_NUM/SEQ_NUM` is created.
+        - In the staging repo we create a corresponding directory `REPO/PR_NUM/SEQ_NUM`.
+        - If a metadata/task file is handled by the staging PR with that sequence number, it is included in
+          that directory.
+        - We iterate over all directories under `REPO/PR_NUM`:
+          - If the metadata/task file is available in the directory, we add the sequence number to the list.
+
+        Note: this is a placeholder for now, as we do not know yet if we need to use a sequence number.
+        """
+        sequence_numbers = {}
+        repo_pr_dir = f"{repo}/{pr}"
+        # iterate over all directories under repo_pr_dir
+        try:
+            directories = self._list_directory_contents(repo_pr_dir)
+            for dir in directories:
+                # check if the directory is a number
+                if dir.name.isdigit():
+                    # determine if a state file with prefix exists in the sequence number directory
+                    # we need to use the basename of the remote file path
+                    remote_file_path_basename = os.path.basename(self.description.task_object.remote_file_path)
+                    state_file_name_prefix = f"{repo_pr_dir}/{dir.name}/{remote_file_path_basename}"
+                    if self._state_file_with_prefix_exists_in_repo_branch(state_file_name_prefix):
+                        sequence_numbers[int(dir.name)] = True
+                    else:
+                        sequence_numbers[int(dir.name)] = False
+                else:
+                    # directory is not a number, so we skip it
+                    continue
+        except FileNotFoundError:
+            # repo_pr_dir does not exist, so we return an empty dictionary
+            return {}
+        except GithubException as err:
+            if err.status != 404:  # 404 is caught by FileNotFoundError
+                # some other error than the directory not existing
+                return {}
+        return sequence_numbers
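+
+    # NOTE (hypothetical example, names invented for illustration): for a source
+    # PR "EESSI/software-layer" number 123 with staging directories 0 and 1, a
+    # return value of {0: True, 1: False} means the task file is handled by the
+    # staging PR with sequence number 0 but not by the one with sequence number 1.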
+
+    @log_function_entry_exit()
+    def _list_directory_contents(self, directory_path: str, branch_name: str = None) -> List[Any]:
+        """
+        List the contents of a directory in a branch.
+        """
+        try:
+            # Get contents of the directory
+            branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+            log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "listing contents of '%s' in branch '%s'",
+                        directory_path, branch_name)
+            contents = self.git_repo.get_contents(directory_path, ref=branch_name)
+
+            # If contents is a list, it means we successfully got directory contents
+            if isinstance(contents, list):
+                return contents
+            else:
+                # If it's not a list, it means the path is not a directory
+                raise ValueError(f"'{directory_path}' is not a directory")
+        except GithubException as err:
+            if err.status == 404:
+                raise FileNotFoundError(f"Directory not found: '{directory_path}'")
+            raise err
+
+    @log_function_entry_exit()
+    def _next_state(self, state: EESSITaskState = None) -> EESSITaskState:
+        """
+        Determine the next state based on the current state using the valid_transitions dictionary.
+
+        NOTE, it assumes that the function is only called for non-terminal states and that the next state is
+        the first element of the list returned by the valid_transitions dictionary.
+        """
+        the_state = state if state is not None else self.determine_state()
+        return self.valid_transitions[the_state][0]
+
+    @log_function_entry_exit()
+    def _path_exists_in_branch(self, path: str, branch_name: str = None) -> bool:
+        """
+        Check if a path exists in a branch.
+        """
+        branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+        try:
+            self.git_repo.get_contents(path, ref=branch_name)
+            return True
+        except GithubException as err:
+            if err.status == 404:
+                return False
+            else:
+                raise err
+
+    @log_function_entry_exit()
+    def _read_dict_from_string(self, content: str) -> dict:
+        """
+        Read the dictionary from the string.
+        """
+        config_dict = {}
+        for line in content.strip().split("\n"):
+            if "=" in line and not line.strip().startswith("#"):  # Skip comments
+                key, value = line.split("=", 1)  # Split only on first '='
+                config_dict[key.strip()] = value.strip()
+        return config_dict
+
+    @log_function_entry_exit()
+    def _read_pull_request_dir_from_file(self, task_pointer_file: str = None, branch_name: str = None) -> str:
+        """
+        Read the pull request directory from the file in the given branch.
+        """
+        # set default values for task pointer file and branch name
+        if task_pointer_file is None:
+            task_pointer_file = self.description.task_object.remote_file_path
+        if branch_name is None:
+            branch_name = self.git_repo.default_branch
+        log_message(LoggingScope.TASK_OPS, "INFO", "reading pull request directory from file '%s' in branch '%s'",
+                    task_pointer_file, branch_name)
+
+        # read the pull request directory from the file in the given branch
+        content = self.git_repo.get_contents(task_pointer_file, ref=branch_name)
+
+        # Decode the content from base64
+        content_str = content.decoded_content.decode("utf-8")
+
+        # Parse into dictionary
+        config_dict = self._read_dict_from_string(content_str)
+
+        target_dir = config_dict.get("target_dir", None)
+        return config_dict.get("pull_request_dir", target_dir)
+
+    @log_function_entry_exit()
+    def _determine_pull_request_dir(self, task_pointer_file: str = None, branch_name: str = None) -> str:
+        """Determine the pull request directory via the task pointer file"""
+        return self._read_pull_request_dir_from_file(task_pointer_file=task_pointer_file, branch_name=branch_name)
+
+    @log_function_entry_exit()
+    def _get_branch_from_name(self, branch_name: str = None) -> Optional[Branch]:
+        """
+        Get a branch object from its name.
+        """
+        branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+
+        try:
+            branch = self.git_repo.get_branch(branch_name)
+            log_message(LoggingScope.TASK_OPS, "INFO", "branch '%s' exists: '%s'", branch_name, branch)
+            return branch
+        except Exception as err:
+            log_message(LoggingScope.TASK_OPS, "ERROR", "error checking if branch '%s' exists: '%s'",
+                        branch_name, err)
+            return None
+
+    @log_function_entry_exit()
+    def _read_task_state_from_file(self, path: str, branch_name: str = None) -> EESSITaskState:
+        """
+        Read the task state from the file in the given branch.
+        """
+        branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+        content = self.git_repo.get_contents(path, ref=branch_name)
+
+        # Decode the content from base64
+        content_str = content.decoded_content.decode("utf-8").strip()
+        log_message(LoggingScope.TASK_OPS, "INFO", "content in TaskState file: '%s'", content_str)
+
+        task_state = EESSITaskState.from_string(content_str)
+        log_message(LoggingScope.TASK_OPS, "INFO", "task state: '%s'", task_state)
+
+        return task_state
+
+    @log_function_entry_exit()
+    def determine_state(self, branch: str = None) -> EESSITaskState:
+        """
+        Determine the state of the task based on the state of the staging repository.
+        """
+        # check if path representing the task file exists in the default branch or the "feature" branch
+        task_pointer_file = self.description.task_object.remote_file_path
+        branch_to_use = self.git_repo.default_branch if branch is None else branch
+
+        if self._path_exists_in_branch(task_pointer_file, branch_name=branch_to_use):
+            log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "path '%s' exists in branch '%s'",
+                        task_pointer_file, branch_to_use)
+
+            # get state from task file in branch to use
+            #  - read the EESSITaskState file in pull request directory
+            pull_request_dir = self._determine_pull_request_dir(branch_name=branch_to_use)
+            log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "pull request directory: '%s'", pull_request_dir)
+            task_state_file_path = f"{pull_request_dir}/TaskState"
+            log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "task state file path: '%s'", task_state_file_path)
+            task_state = self._read_task_state_from_file(task_state_file_path, branch_to_use)
+
+            log_message(LoggingScope.TASK_OPS, "INFO", "task state in branch '%s': %s",
+                        branch_to_use, task_state)
+            return task_state
+        else:
+            log_message(LoggingScope.TASK_OPS, "INFO", "path '%s' does not exist in branch '%s'",
+                        task_pointer_file, branch_to_use)
+            return EESSITaskState.UNDETERMINED
+
+    @log_function_entry_exit()
+    def handle(self):
+        """
+        Dynamically find and execute the appropriate handler based on action and state.
+        """
+        state_before_handle = self.determine_state()
+
+        # Construct handler method name
+        handler_name = f"_handle_{self.action}_{str(state_before_handle).lower()}"
+
+        # Check if the handler exists
+        handler = getattr(self, handler_name, None)
+
+        if handler and callable(handler):
+            # Execute the handler if it exists
+            return handler()
+        else:
+            # Default behavior for missing handlers
+            log_message(LoggingScope.TASK_OPS, "ERROR",
+                        "No handler for action '%s' and state '%s' implemented; nothing to be done",
+                        self.action, state_before_handle)
+            return state_before_handle
+
+    # Implement handlers for ADD action
+    @log_function_entry_exit()
+    def _safe_create_file(self, path: str, message: str, content: str, branch_name: str = None):
+        """Create a file in the given branch."""
+        try:
+            branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+            existing_file = self.git_repo.get_contents(path, ref=branch_name)
+            log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "File '%s' already exists", path)
+            return existing_file
+        except GithubException as err:
+            if err.status == 404:  # File doesn't exist
+                # Safe to create
+                return self.git_repo.create_file(path, message, content, branch=branch_name)
+            else:
+                raise err  # Some other error
+
+    @log_function_entry_exit()
+    def _create_multi_file_commit(self, files_data, commit_message, branch_name: str = None):
+        """
+        Create a commit with multiple file changes
+
+        files_data: dict with structure:
+        {
+            "path/to/file1.txt": {
+                "content": "file content",
+                "mode": "100644"  # optional, defaults to 100644
+            },
+            "path/to/file2.py": {
+                "content": "print('hello')",
+                "mode": "100644"
+            }
+        }
+        """
+        branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+        ref = self.git_repo.get_git_ref(f"heads/{branch_name}")
+        current_commit = self.git_repo.get_git_commit(ref.object.sha)
+        base_tree = current_commit.tree
+
+        # Create tree elements
+        tree_elements = []
+        for file_path, file_info in files_data.items():
+            content = file_info["content"]
+            if isinstance(content, str):
+                content = content.encode("utf-8")
+
+            blob = self.git_repo.create_git_blob(
+                base64.b64encode(content).decode("utf-8"),
+                "base64"
+            )
+            tree_elements.append(InputGitTreeElement(
+                path=file_path,
+                mode=file_info.get("mode", "100644"),
+                type="blob",
+                sha=blob.sha
+            ))
+
+        # Create new tree
+        new_tree = self.git_repo.create_git_tree(tree_elements, base_tree)
+
+        # Create commit
+        new_commit = self.git_repo.create_git_commit(
+            commit_message,
+            new_tree,
+            [current_commit]
+        )
+
+        # Update branch reference
+        ref.edit(new_commit.sha)
+
+        return new_commit
+
+    @log_function_entry_exit()
+    def _update_file(
+        self, file_path: str, new_content: str, commit_message: str, branch_name: str = None
+    ) -> Optional[Dict]:
+        try:
+            branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+
+            # get the current file
+            file = self.git_repo.get_contents(file_path, ref=branch_name)
+
+            # update the file
+            result = self.git_repo.update_file(
+                path=file_path,
+                message=commit_message,
+                content=new_content,
+                sha=file.sha,
+                branch=branch_name
+            )
+
+            log_message(LoggingScope.TASK_OPS_DETAILS, "INFO",
+                        "File updated successfully. Commit SHA: '%s'", result["commit"].sha)
+            return result
+
+        except Exception as err:
+            log_message(LoggingScope.TASK_OPS, "ERROR", "Error updating file: '%s'", err)
+            return None
+
+    @log_function_entry_exit()
+    def _sorted_list_of_sequence_numbers(self) -> List[int]:
+        """Create a sorted list of sequence numbers from the pull requests directory"""
+        # a pull request's directory is of the form REPO/PR/SEQ
+        # hence, we can get all sequence numbers from the pull requests directory REPO/PR
+        sequence_numbers = []
+        repo_pr_dir = f"{self.description.get_repo_name()}/{self.description.get_pr_number()}"
+
+        # iterate over all directories under repo_pr_dir
+        try:
+            directories = self._list_directory_contents(repo_pr_dir)
+            for dir in directories:
+                # check if the directory is a number
+                if dir.name.isdigit():
+                    sequence_numbers.append(int(dir.name))
+                else:
+                    # directory is not a number, so we skip it
+                    continue
+        except FileNotFoundError:
+            # repo_pr_dir does not exist, so we return an empty list
+            log_message(LoggingScope.TASK_OPS, "ERROR", "Pull requests directory '%s' does not exist", repo_pr_dir)
+        except GithubException as err:
+            if err.status != 404:  # 404 is caught by FileNotFoundError
+                # some other error than the directory not existing
+                log_message(LoggingScope.TASK_OPS, "ERROR",
+                            "Some other error than the directory not existing: '%s'", err)
+        except Exception as err:
+            log_message(LoggingScope.TASK_OPS, "ERROR", "Unexpected error: '%s'", err)
+
+        return sorted(sequence_numbers)
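+
+    # NOTE (editorial summary of _determine_sequence_number below): the highest
+    # existing sequence number is reused while its staging PR is still open or
+    # does not exist yet; a merged or closed staging PR bumps the sequence
+    # number by one, so a new staging PR is started for the same source PR.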
+
+    @log_function_entry_exit()
+    def _determine_sequence_number(self) -> int:
+        """Determine the sequence number for the task"""
+
+        sequence_numbers = self._sorted_list_of_sequence_numbers()
+        log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "number of sequence numbers: %d", len(sequence_numbers))
+        if len(sequence_numbers) == 0:
+            log_message(LoggingScope.TASK_OPS, "INFO", "no sequence numbers found, returning 0")
+            return 0
+
+        log_message(LoggingScope.TASK_OPS_DETAILS, "INFO",
+                    "sequence numbers: [%s]", ", ".join(map(str, sequence_numbers)))
+
+        # get the highest sequence number
+        highest_sequence_number = sequence_numbers[-1]
+        log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "highest sequence number: %d", highest_sequence_number)
+
+        pull_request = self._find_pr_for_sequence_number(highest_sequence_number)
+        log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "pull request: '%s'", pull_request)
+
+        if pull_request is None:
+            log_message(LoggingScope.TASK_OPS, "INFO", "Did not find pull request for sequence number %d",
+                        highest_sequence_number)
+            # the directory for the sequence number exists but no PR yet
+            return highest_sequence_number
+        else:
+            log_message(LoggingScope.TASK_OPS, "INFO", "pull request found: '%s'", pull_request)
+            log_message(LoggingScope.TASK_OPS, "INFO", "pull request state/merged: '%s/%s'",
+                        pull_request.state, str(pull_request.is_merged()))
+            if pull_request.is_merged():
+                # the PR is merged, so we use the next sequence number
+                return highest_sequence_number + 1
+            else:
+                # the PR is not merged, it may be closed though
+                if pull_request.state == 'closed':
+                    # PR has been closed, so we return the next sequence number
+                    return highest_sequence_number + 1
+                else:
+                    # PR is not closed, so we return the current highest sequence number
+                    return highest_sequence_number
+
+    @log_function_entry_exit()
+    def _handle_add_undetermined(self):
+        """Handler for ADD action in UNDETERMINED state"""
+        log_message(LoggingScope.TASK_OPS, "INFO", "Handling ADD action in UNDETERMINED state: '%s'",
+                    self.description.get_task_file_name())
+        # task is in state UNDETERMINED if there is no pull request directory for the task yet
+        #
+        # create pull request directory (REPO/PR/SEQ/TASK_FILE_NAME/)
+        # create task file in pull request directory (PULL_REQUEST_DIR/TaskDescription)
+        # create task status file in pull request directory (PULL_REQUEST_DIR/TaskState)
+        # create pointer file from task file path to pull request directory (remote_file_path -> PULL_REQUEST_DIR)
+        repo_name = self.description.get_repo_name()
+        pr_number = self.description.get_pr_number()
+        sequence_number = self._determine_sequence_number()  # corresponds to an open or yet to be created PR
+        task_file_name = self.description.get_task_file_name()
+        # we cannot use self._determine_pull_request_dir() here because it requires a task pointer file
+        # and we don't have one yet
+        pull_request_dir = f"{repo_name}/{pr_number}/{sequence_number}/{task_file_name}"
+        task_description_file_path = f"{pull_request_dir}/TaskDescription"
+        task_state_file_path = f"{pull_request_dir}/TaskState"
+        remote_file_path = self.description.task_object.remote_file_path
+
+        files_to_commit = {
+            task_description_file_path: {
+                "content": self.description.get_contents(),
+                "mode": "100644"
+            },
+            task_state_file_path: {
+                "content": f"{EESSITaskState.NEW_TASK.name}\n",
+                "mode": "100644"
+            },
+            remote_file_path: {
+                "content": f"remote_file_path = {remote_file_path}\npull_request_dir = {pull_request_dir}",
+                "mode": "100644"
+            }
+        }
+
+        branch_name = self.git_repo.default_branch
+        try:
+            commit = self._create_multi_file_commit(
+                files_to_commit,
+                f"new task for {repo_name} PR {pr_number} seq {sequence_number}",
+                branch_name=branch_name
+            )
+            log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "commit created: '%s'", commit)
+        except Exception as err:
+            log_message(LoggingScope.TASK_OPS, "ERROR", "Error creating commit: '%s'", err)
+            # TODO: rollback previous changes (task description file, task state file)
+            return EESSITaskState.UNDETERMINED
+
+        # TODO: verify that the sequence number is still valid (PR corresponding to the sequence number
+        # is still open or yet to be created); if it is not valid, perform corrective actions
+        return EESSITaskState.NEW_TASK
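+
+    # NOTE (illustrative, derived from _handle_add_undetermined above; repo name,
+    # PR number and task file name are hypothetical): for source PR
+    # EESSI/software-layer number 123, sequence number 0 and task file
+    # "task.yml", the commit creates
+    #     EESSI/software-layer/123/0/task.yml/TaskDescription
+    #     EESSI/software-layer/123/0/task.yml/TaskState   (containing "NEW_TASK")
+    # plus a pointer file at the task's remote path recording pull_request_dir.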
+
+    @log_function_entry_exit()
+    def _update_task_state_file(self, next_state: EESSITaskState,
+                                branch_name: Optional[str] = None) -> Optional[Dict]:
+        """Update the TaskState file content in default or given branch"""
+        branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+
+        task_pointer_file = self.description.task_object.remote_file_path
+        pull_request_dir = self._read_pull_request_dir_from_file(task_pointer_file, branch_name)
+        task_state_file_path = f"{pull_request_dir}/TaskState"
+        arch = self.description.get_metadata_filename_components()[3]
+        commit_message = f"change task state to {next_state} in {branch_name} for {arch}"
+        result = self._update_file(task_state_file_path,
+                                   f"{next_state.name}\n",
+                                   commit_message,
+                                   branch_name=branch_name)
+        return result
+
+    @log_function_entry_exit()
+    def _init_payload_object(self):
+        """Initialize the payload object"""
+        if self.payload is not None:
+            log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "payload object already initialized")
+            return
+
+        # get name of payload from metadata
+        payload_name = self.description.metadata["payload"]["filename"]
+        log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "payload_name: '%s'", payload_name)
+
+        # get config and remote_client from self.description.task_object
+        config = self.description.task_object.config
+        remote_client = self.description.task_object.remote_client
+
+        # determine remote_file_path by replacing basename of remote_file_path in self.description.task_object
+        # with payload_name
+        description_remote_file_path = self.description.task_object.remote_file_path
+        payload_remote_file_path = os.path.join(os.path.dirname(description_remote_file_path), payload_name)
+        log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "payload_remote_file_path: '%s'", payload_remote_file_path)
+
+        # initialize payload object
+        payload_object = EESSIDataAndSignatureObject(config, payload_remote_file_path, remote_client)
+        self.payload = EESSITaskPayload(payload_object)
+        log_message(LoggingScope.TASK_OPS, "INFO", "payload: '%s'", self.payload)
+
+    @log_function_entry_exit()
+    def _handle_add_new_task(self):
+        """Handler for ADD action in NEW_TASK state"""
+        log_message(LoggingScope.TASK_OPS, "INFO", "Handling ADD action in NEW_TASK state: '%s'",
+                    self.description.get_task_file_name())
+        # determine next state
+        next_state = self._next_state(EESSITaskState.NEW_TASK)
+        log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "next_state: '%s'", next_state)
+
+        # initialize payload object
+        self._init_payload_object()
+
+        # update TaskState file content
+        self._update_task_state_file(next_state)
+
+        # TODO: verify that the sequence number is still valid (PR corresponding to the sequence number
+        # is still open or yet to be created); if it is not valid, perform corrective actions
+        return next_state
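+
+    # Example (hypothetical paths, for documentation only): _init_payload_object()
+    # derives the payload's remote path by replacing the basename of the task's
+    # remote path with the payload filename taken from the metadata:
+    #
+    #   description_remote_file_path = "new/software.eessi.io/2023.06/task-01.meta"
+    #   payload_name                 = "eessi-2023.06-software.tar.gz"
+    #   payload_remote_file_path     = "new/software.eessi.io/2023.06/eessi-2023.06-software.tar.gz"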
+
+    @log_function_entry_exit()
+    def _find_pr_for_branch(self, branch_name: str) -> Optional[PullRequest]:
+        """
+        Find the single PR for the given branch in any state.
+
+        Args:
+            branch_name: name of the branch
+
+        Returns:
+            PullRequest object if found, None otherwise
+        """
+        try:
+            prs = [pr for pr in list(self.git_repo.get_pulls(state="all"))
+                   if pr.head.ref == branch_name]
+            log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "number of PRs found: %d", len(prs))
+            if prs:
+                log_message(LoggingScope.TASK_OPS_DETAILS, "INFO",
+                            "1st PR found: %d, '%s'", prs[0].number, prs[0].head.ref)
+            return prs[0] if prs else None
+        except Exception as err:
+            log_message(LoggingScope.TASK_OPS, "ERROR", "Error finding PR for branch '%s': '%s'", branch_name, err)
+            return None
+
+    @log_function_entry_exit()
+    def _find_pr_for_sequence_number(self, sequence_number: int) -> Optional[PullRequest]:
+        """Find the PR for the given sequence number"""
+        repo_name = self.description.get_repo_name()
+        pr_number = self.description.get_pr_number()
+        feature_branch_name = f"{repo_name.replace('/', '-')}-PR-{pr_number}-SEQ-{sequence_number}"
+
+        # list all PRs whose head_ref starts with the feature branch name without the sequence number
+        last_dash = feature_branch_name.rfind("-")
+        if last_dash != -1:
+            head_ref_wout_seq_num = feature_branch_name[:last_dash + 1]  # +1 to include the separator
+        else:
+            head_ref_wout_seq_num = feature_branch_name
+
+        log_message(LoggingScope.TASK_OPS_DETAILS, "INFO",
+                    "searching for PRs whose head_ref starts with: '%s'", head_ref_wout_seq_num)
+
+        all_prs = [pr for pr in list(self.git_repo.get_pulls(state="all"))
+                   if pr.head.ref.startswith(head_ref_wout_seq_num)]
+        log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "  number of PRs found: %d", len(all_prs))
+        for pr in all_prs:
+            log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "  PR #%d: '%s'", pr.number, pr.head.ref)
+
+        # now, find the PR for the feature branch name (if any)
+        log_message(LoggingScope.TASK_OPS_DETAILS, "INFO",
+                    "searching PR for feature branch name: '%s'", feature_branch_name)
+        pull_request = self._find_pr_for_branch(feature_branch_name)
+        log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "pull request for branch '%s': '%s'",
+                    feature_branch_name, pull_request)
+        return pull_request
+
+    @log_function_entry_exit()
+    def _determine_sequence_number_from_pull_request_directory(self) -> int:
+        """Determine the sequence number from the pull request directory name"""
+        task_pointer_file = self.description.task_object.remote_file_path
+        pull_request_dir = self._read_pull_request_dir_from_file(task_pointer_file, self.git_repo.default_branch)
+        # pull_request_dir is of the form REPO/PR/SEQ/TASK_FILE_NAME (REPO contains a '/' separating org and repo)
+        _, _, _, seq, _ = pull_request_dir.split("/")
+        return int(seq)
+
+    @log_function_entry_exit()
+    def _determine_feature_branch_name(self) -> str:
+        """Determine the feature branch name from the pull request directory name"""
+        task_pointer_file = self.description.task_object.remote_file_path
+        pull_request_dir = self._read_pull_request_dir_from_file(task_pointer_file, self.git_repo.default_branch)
+        # pull_request_dir is of the form REPO/PR/SEQ/TASK_FILE_NAME (REPO contains a '/' separating org and repo)
+        org, repo, pr, seq, _ = pull_request_dir.split("/")
+        return f"{org}-{repo}-PR-{pr}-SEQ-{seq}"
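+
+    # Example (hypothetical values, for documentation only): for pull_request_dir
+    # "EESSI/software-layer/42/0/task.yml", _determine_feature_branch_name()
+    # returns "EESSI-software-layer-PR-42-SEQ-0"; dropping everything after the
+    # last dash yields the prefix "EESSI-software-layer-PR-42-SEQ-" that
+    # _find_pr_for_sequence_number() uses to list all PRs of the same group.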
+
+    @log_function_entry_exit()
+    def _update_task_states(self, next_state: EESSITaskState, default_branch_name: str,
+                            approved_state: EESSITaskState, feature_branch_name: str):
+        """
+        Update task states in default and feature branches.
+
+        States have to be updated in a specific order; in particular, the default branch has to be
+        merged into the feature branch before the feature branch can be updated, to avoid a merge
+        conflict.
+
+        Args:
+            next_state: next state to be applied to the default branch
+            default_branch_name: name of the default branch
+            approved_state: state to be applied to the feature branch
+            feature_branch_name: name of the feature branch
+        """
+        # TODO: add failure handling (capture failures and return them somehow)
+
+        # update TaskState file content
+        # - next_state in default branch (interpreted as current state)
+        # - approved_state in feature branch (interpreted as future state, i.e., after
+        #   the PR corresponding to the feature branch has been merged)
+
+        # first, update the task state file in the default branch
+        self._update_task_state_file(next_state, branch_name=default_branch_name)
+
+        # second, merge default branch into feature branch (to avoid a merge conflict)
+        # TODO: store arch info (CPU+ACCEL) in task/metadata file and then access that rather
+        #       than using a part of the file name
+        arch = self.description.get_metadata_filename_components()[3]
+        commit_message = f"merge {default_branch_name} into {feature_branch_name} for {arch}"
+        self.git_repo.merge(
+            head=default_branch_name,
+            base=feature_branch_name,
+            commit_message=commit_message
+        )
+
+        # last, update task state file in feature branch
+        self._update_task_state_file(approved_state, branch_name=feature_branch_name)
+        log_message(LoggingScope.TASK_OPS, "INFO",
+                    "TaskState file updated to '%s' in default branch '%s' and to '%s' in feature branch '%s'",
+                    next_state, default_branch_name, approved_state, feature_branch_name)
+
+    @log_function_entry_exit()
+    def _create_task_summary(self) -> str:
+        """Analyse contents of the current task and create a summary file for it in the REPO-PR-SEQ directory."""
+
+        # determine task summary file path in feature branch on GitHub
+        feature_branch_name = self._determine_feature_branch_name()
+        pull_request_dir = self._determine_pull_request_dir(branch_name=feature_branch_name)
+        task_summary_file_path = f"{pull_request_dir}/TaskSummary.html"
+
+        # check if task summary file already exists in repo on GitHub
+        if self._path_exists_in_branch(task_summary_file_path, feature_branch_name):
+            log_message(LoggingScope.TASK_OPS_DETAILS, "INFO",
+                        "task summary file already exists: '%s'", task_summary_file_path)
+            file_contents = self.git_repo.get_contents(task_summary_file_path, ref=feature_branch_name)
+            # decode to str so the return value matches the declared return type
+            return file_contents.decoded_content.decode("utf-8")
+
+        # create task summary
+        payload_name = self.description.metadata["payload"]["filename"]
+        payload_summary = self.payload.analyse_contents(self.config)
+        metadata_contents = self.description.get_contents()
+
+        task_summary = self.config["github"]["task_summary_payload_template"].format(
+            payload_name=payload_name,
+            metadata_contents=metadata_contents,
+            payload_overview=payload_summary
+        )
+
+        # create HTML file with task summary in REPO-PR-SEQ directory
+        # TODO: add failure handling (capture result and act on it)
+        task_file_name = self.description.get_task_file_name()
+        commit_message = f"create summary for {task_file_name} in {feature_branch_name}"
+        self._safe_create_file(task_summary_file_path, commit_message, task_summary,
+                               branch_name=feature_branch_name)
+        log_message(LoggingScope.TASK_OPS_DETAILS, "INFO", "task summary file created: '%s'", task_summary_file_path)
+
+        # return task summary
+        return task_summary
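+
+    # Sketch of a plausible task_summary_payload_template config entry (assumed
+    # here for illustration, not taken from the actual configuration; the format
+    # keys match the format() call above):
+    #
+    #   [github]
+    #   task_summary_payload_template =
+    #       <details>
+    #         <summary>{payload_name}</summary>
+    #         <pre>{metadata_contents}</pre>
+    #         {payload_overview}
+    #       </details>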
"""Create a contents overview for the pull request""" + # TODO: implement + feature_branch_name = self._determine_feature_branch_name() + task_pointer_file = self.description.task_object.remote_file_path + pull_request_dir = self._read_pull_request_dir_from_file(task_pointer_file, feature_branch_name) + pr_dir = os.path.dirname(pull_request_dir) + directories = self._list_directory_contents(pr_dir, feature_branch_name) + contents_overview = "" + if directories: + contents_overview += "\n" + for directory in directories: + task_summary_file_path = f"{pr_dir}/{directory.name}/TaskSummary.html" + if self._path_exists_in_branch(task_summary_file_path, feature_branch_name): + file_contents = self.git_repo.get_contents(task_summary_file_path, ref=feature_branch_name) + task_summary = base64.b64decode(file_contents.content).decode("utf-8") + contents_overview += f"{task_summary}\n" + else: + contents_overview += f"Task summary file not found: {task_summary_file_path}\n" + contents_overview += "\n" + else: + contents_overview += "No tasks found in this PR\n" + + print(f"contents_overview: {contents_overview}") + return contents_overview + + @log_function_entry_exit() + def _create_pull_request(self, feature_branch_name: str, default_branch_name: str): + """ + Create a PR from the feature branch to the default branch + + Args: + feature_branch_name: name of the feature branch + default_branch_name: name of the default branch + """ + pr_title_format = self.config["github"]["grouped_pr_title"] + pr_body_format = self.config["github"]["grouped_pr_body"] + repo_name = self.description.get_repo_name() + pr_number = self.description.get_pr_number() + pr_url = f"https://github.com/{repo_name}/pull/{pr_number}" + seq_num = self._determine_sequence_number_from_pull_request_directory() + pr_title = pr_title_format.format( + cvmfs_repo=self.cvmfs_repo, + pr=pr_number, + repo=repo_name, + seq_num=seq_num, + ) + self._create_task_summary() + contents_overview = self._create_pr_contents_overview() + pr_body = pr_body_format.format( + cvmfs_repo=self.cvmfs_repo, + pr=pr_number, + pr_url=pr_url, + repo=repo_name, + seq_num=seq_num, + contents=contents_overview, + analysis="