diff --git a/openeo/extra/job_management/__init__.py b/openeo/extra/job_management/__init__.py index b0e448e2e..040a17901 100644 --- a/openeo/extra/job_management/__init__.py +++ b/openeo/extra/job_management/__init__.py @@ -32,12 +32,17 @@ from requests.adapters import HTTPAdapter, Retry from openeo import BatchJob, Connection +from openeo.extra.job_management._thread_worker import ( + _JobManagerWorkerThreadPool, + _JobStartTask, +) from openeo.internal.processes.parse import ( Parameter, Process, parse_remote_process_definition, ) from openeo.rest import OpenEoApiError +from openeo.rest.auth.auth import BearerAuth from openeo.util import LazyLoadCache, deep_get, repr_truncate, rfc3339 _log = logging.getLogger(__name__) @@ -105,6 +110,7 @@ def get_by_status(self, statuses: List[str], max=None) -> pd.DataFrame: """ ... + def _start_job_default(row: pd.Series, connection: Connection, *args, **kwargs): raise NotImplementedError("No 'start_job' callable provided") @@ -186,6 +192,7 @@ def start_job( # Expected columns in the job DB dataframes. # TODO: make this part of public API when settled? + # TODO: move non official statuses to seperate column (not_started, queued_for_start) _COLUMN_REQUIREMENTS: Mapping[str, _ColumnProperties] = { "id": _ColumnProperties(dtype="str"), "backend_name": _ColumnProperties(dtype="str"), @@ -222,6 +229,7 @@ def __init__( datetime.timedelta(seconds=cancel_running_job_after) if cancel_running_job_after is not None else None ) self._thread = None + self._worker_pool = None def add_backend( self, @@ -358,6 +366,7 @@ def start_job_thread(self, start_job: Callable[[], BatchJob], job_db: JobDatabas _log.info(f"Resuming `run_jobs` from existing {job_db}") self._stop_thread = False + self._worker_pool = _JobManagerWorkerThreadPool() def run_loop(): @@ -365,14 +374,19 @@ def run_loop(): stats = collections.defaultdict(int) while ( - sum(job_db.count_by_status(statuses=["not_started", "created", "queued", "running"]).values()) > 0 + sum( + job_db.count_by_status( + statuses=["not_started", "created", "queued", "queued_for_start", "running"] + ).values() + ) + > 0 and not self._stop_thread ): - self._job_update_loop(job_db=job_db, start_job=start_job) + self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats) stats["run_jobs loop"] += 1 + # Show current stats and sleep _log.info(f"Job status histogram: {job_db.count_by_status()}. Run stats: {dict(stats)}") - # Do sequence of micro-sleeps to allow for quick thread exit for _ in range(int(max(1, self.poll_sleep))): time.sleep(1) if self._stop_thread: @@ -391,6 +405,8 @@ def stop_job_thread(self, timeout_seconds: Optional[float] = _UNSET): .. versionadded:: 0.32.0 """ + self._worker_pool.shutdown() + if self._thread is not None: self._stop_thread = True if timeout_seconds is _UNSET: @@ -493,7 +509,16 @@ def run_jobs( # TODO: support user-provided `stats` stats = collections.defaultdict(int) - while sum(job_db.count_by_status(statuses=["not_started", "created", "queued", "running"]).values()) > 0: + self._worker_pool = _JobManagerWorkerThreadPool() + + while ( + sum( + job_db.count_by_status( + statuses=["not_started", "created", "queued_for_start", "queued", "running"] + ).values() + ) + > 0 + ): self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats) stats["run_jobs loop"] += 1 @@ -502,6 +527,9 @@ def run_jobs( time.sleep(self.poll_sleep) stats["sleep"] += 1 + # TODO; run post process after shutdown once more to ensure completion? 
+ self._worker_pool.shutdown() + return stats def _job_update_loop( @@ -524,7 +552,7 @@ def _job_update_loop( not_started = job_db.get_by_status(statuses=["not_started"], max=200).copy() if len(not_started) > 0: # Check number of jobs running at each backend - running = job_db.get_by_status(statuses=["created", "queued", "running"]) + running = job_db.get_by_status(statuses=["created", "queued", "queued_for_start", "running"]) stats["job_db get_by_status"] += 1 per_backend = running.groupby("backend_name").size().to_dict() _log.info(f"Running per backend: {per_backend}") @@ -541,7 +569,9 @@ def _job_update_loop( stats["job_db persist"] += 1 total_added += 1 - # Act on jobs + self._process_threadworker_updates(self._worker_pool, job_db, stats) + + # TODO: move this back closer to the `_track_statuses` call above, once job done/error handling is also handled in threads? for job, row in jobs_done: self.on_job_done(job, row) @@ -551,7 +581,6 @@ def _job_update_loop( for job, row in jobs_cancel: self.on_job_cancel(job, row) - def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = None): """Helper method for launching jobs @@ -598,6 +627,7 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No df.loc[i, "start_time"] = rfc3339.now_utc() if job: df.loc[i, "id"] = job.job_id + _log.info(f"Job created: {job.job_id}") with ignore_connection_errors(context="get status"): status = job.status() stats["job get status"] += 1 @@ -605,19 +635,93 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No if status == "created": # start job if not yet done by callback try: - job.start() - stats["job start"] += 1 - df.loc[i, "status"] = job.status() - stats["job get status"] += 1 + job_con = job.connection + task = _JobStartTask( + root_url=job_con.root_url, + bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None, + job_id=job.job_id, + df_idx = i + ) + _log.info(f"Submitting task {task} to thread pool") + self._worker_pool.submit_task(task) + + stats["job_queued_for_start"] += 1 + df.loc[i, "status"] = "queued_for_start" except OpenEoApiError as e: - _log.error(e) - df.loc[i, "status"] = "start_failed" - stats["job start error"] += 1 + _log.info(f"Failed submitting task {task} to thread pool with error: {e}") + df.loc[i, "status"] = "queued_for_start_failed" + stats["job queued for start failed"] += 1 else: # TODO: what is this "skipping" about actually? df.loc[i, "status"] = "skipped" stats["start_job skipped"] += 1 + def _process_threadworker_updates( + self, + worker_pool: '_JobManagerWorkerThreadPool', + job_db: 'JobDatabaseInterface', + stats: Dict[str, int], + ) -> None: + """ + Fetches completed TaskResult objects from the worker pool and applies + their db_update and stats_updates. Only existing DataFrame rows + (matched by df_idx) are upserted via job_db.persist(). Any results + targeting unknown df_idx indices are logged as errors but not persisted. 
+ + + + :param worker_pool: Thread-pool managing asynchronous Task executes + :param job_db: Interface to append/upsert to the job database + :param stats: Dictionary accumulating statistic counters + """ + # Retrieve completed task results immediately + results, _ = worker_pool.process_futures(timeout=0) + + # Collect update dicts + updates: List[Dict[str, Any]] = [] + for res in results: + # Process database updates + if res.db_update: + try: + updates.append({ + 'id': res.job_id, + 'df_idx': res.df_idx, + **res.db_update, + }) + except Exception as e: + _log.error(f"Skipping invalid db_update '{res.db_update}' for job '{res.job_id}': {e}", ) + + # Process stats updates + if res.stats_update: + try: + for key, val in res.stats_update.items(): + count = int(val) + stats[key] = stats.get(key, 0) + count + except Exception as e: + _log.error( + f"Skipping invalid stats_update {res.stats_update} for job '{res.job_id}': {e}" + ) + + # No valid updates: nothing to persist + if not updates: + return + + # Build DataFrame of updates indexed by df_idx + df_updates = pd.DataFrame(updates).set_index('df_idx', drop=True) + + # Determine which rows to upsert + existing_indices = set(df_updates.index).intersection(job_db.read().index) + if existing_indices: + df_upsert = df_updates.loc[sorted(existing_indices)] + job_db.persist(df_upsert) + stats['job_db persist'] = stats.get('job_db persist', 0) + 1 + + # Any df_idx not in original index are errors + missing = set(df_updates.index) - existing_indices + if missing: + _log.error(f"Skipping non-existing dataframe indiches: {sorted(missing)}") + + def on_job_done(self, job: BatchJob, row): """ Handles jobs that have finished. Can be overridden to provide custom behaviour. @@ -673,7 +777,7 @@ def _cancel_prolonged_job(self, job: BatchJob, row): try: # Ensure running start time is valid job_running_start_time = rfc3339.parse_datetime(row.get("running_start_time"), with_timezone=True) - + # Parse the current time into a datetime object with timezone info current_time = rfc3339.parse_datetime(rfc3339.now_utc(), with_timezone=True) @@ -681,12 +785,11 @@ def _cancel_prolonged_job(self, job: BatchJob, row): elapsed = current_time - job_running_start_time if elapsed > self._cancel_running_job_after: - _log.info( f"Cancelling long-running job {job.job_id} (after {elapsed}, running since {job_running_start_time})" ) job.stop() - + except Exception as e: _log.error(f"Unexpected error while handling job {job.job_id}: {e}") @@ -715,7 +818,7 @@ def _track_statuses(self, job_db: JobDatabaseInterface, stats: Optional[dict] = """ stats = stats if stats is not None else collections.defaultdict(int) - active = job_db.get_by_status(statuses=["created", "queued", "running"]).copy() + active = job_db.get_by_status(statuses=["created", "queued", "queued_for_start", "running"]).copy() jobs_done = [] jobs_error = [] @@ -737,7 +840,7 @@ def _track_statuses(self, job_db: JobDatabaseInterface, stats: Optional[dict] = f"Status of job {job_id!r} (on backend {backend_name}) is {new_status!r} (previously {previous_status!r})" ) - if new_status == "finished": + if previous_status != "finished" and new_status == "finished": stats["job finished"] += 1 jobs_done.append((the_job, active.loc[i])) @@ -749,7 +852,7 @@ def _track_statuses(self, job_db: JobDatabaseInterface, stats: Optional[dict] = stats["job canceled"] += 1 jobs_cancel.append((the_job, active.loc[i])) - if previous_status in {"created", "queued"} and new_status == "running": + if previous_status in {"created", "queued", 
"queued_for_start"} and new_status == "running": stats["job started running"] += 1 active.loc[i, "running_start_time"] = rfc3339.now_utc() @@ -873,11 +976,12 @@ def get_by_status(self, statuses, max=None) -> pd.DataFrame: def _merge_into_df(self, df: pd.DataFrame): if self._df is not None: - self._df.update(df, overwrite=True) + self._df.update(df, overwrite=True) else: self._df = df + class CsvJobDatabase(FullDataFrameJobDatabase): """ Persist/load job metadata with a CSV file. diff --git a/openeo/extra/job_management/_thread_worker.py b/openeo/extra/job_management/_thread_worker.py new file mode 100644 index 000000000..57171be1d --- /dev/null +++ b/openeo/extra/job_management/_thread_worker.py @@ -0,0 +1,206 @@ +""" +Internal utilities to handle job management tasks through threads. +""" + +import concurrent.futures +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union + +import openeo + +_log = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class _TaskResult: + """ + Container for the result of a task execution. + Used to communicate the outcome of job-related tasks. + + :param job_id: + The ID of the job this result is associated with. + + :param df_idx: + The index of job's row in the dataframe. + + :param db_update: + Optional dictionary describing updates to apply to a job database, + such as status changes. Defaults to an empty dict. + + :param stats_update: + Optional dictionary capturing statistical counters or metrics, + e.g., number of successful starts or errors. Defaults to an empty dict. + """ + + job_id: str # Mandatory + df_idx: int # Mandatory + db_update: Dict[str, Any] = field(default_factory=dict) # Optional + stats_update: Dict[str, int] = field(default_factory=dict) # Optional + + +@dataclass(frozen=True) +class Task(ABC): + """ + Abstract base class for a unit of work associated with a job (identified by a job id) + and to be processed by :py:classs:`_JobManagerWorkerThreadPool`. + + Because the work is intended to be executed in a thread/process pool, + it is recommended to keep the state of the task object as simple/immutable as possible + (e.g. just some string/number attributes) and avoid sharing complex objects and state. + + The main API for subclasses to implement is the `execute`method + which should return a :py:class:`_TaskResult` object. + with job-related metadata and updates. + + :param job_id: + Identifier of the job to start on the backend. + + :param df_idx: + Index of the row of the job in the dataframe. + + """ + + job_id: str + df_idx: int + + @abstractmethod + def execute(self) -> _TaskResult: + """Execute the task and return a raw result""" + pass + + +@dataclass(frozen=True) +class ConnectedTask(Task): + """ + Base class for tasks that involve an (authenticated) connection to a backend. + + Backend is specified by a root URL, + and (optional) authentication is done through an openEO-style bearer token. + + :param root_url: + The root URL of the OpenEO backend to connect to. + + :param bearer_token: + Optional Bearer token used for authentication. + + """ + + root_url: str + bearer_token: Optional[str] + + def get_connection(self) -> openeo.Connection: + connection = openeo.connect(self.root_url) + if self.bearer_token: + connection.authenticate_bearer_token(self.bearer_token) + return connection + + +class _JobStartTask(ConnectedTask): + """ + Task for starting an openEO batch job (the `POST /jobs//result` request). 
+ """ + + def execute(self) -> _TaskResult: + """ + Start job identified by `job_id` on the backend. + + :returns: + A `_TaskResult` with status and statistics metadata, indicating + success or failure of the job start. + """ + # TODO: move main try-except block to base class? + try: + job = self.get_connection().job(self.job_id) + # TODO: only start when status is "queued"? + job.start() + _log.info(f"Job {self.job_id!r} started successfully") + return _TaskResult( + job_id=self.job_id, + df_idx = self.df_idx, + db_update={"status": "queued"}, + stats_update={"job start": 1}, + ) + except Exception as e: + _log.error(f"Failed to start job {self.job_id!r}: {e!r}") + # TODO: more insights about the failure (e.g. the exception) are just logged, but lost from the result + return _TaskResult( + job_id=self.job_id, + df_idx = self.df_idx, + db_update={"status": "start_failed"}, + stats_update={"start_job error": 1} + ) + + +class _JobManagerWorkerThreadPool: + """ + Thread pool-based worker that manages the execution of asynchronous tasks. + + Internally wraps a `ThreadPoolExecutor` and manages submission, + tracking, and result processing of tasks. + + :param max_workers: + Maximum number of concurrent threads to use for execution. + Defaults to 2. + """ + + def __init__(self, max_workers: int = 2): + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) + self._future_task_pairs: List[Tuple[concurrent.futures.Future, Task]] = [] + + def submit_task(self, task: Task) -> None: + """ + Submit a task to the thread pool executor. + + Tasks are scheduled for asynchronous execution and tracked + internally to allow later processing of their results. + + :param task: + An instance of `Task` to be executed. + """ + future = self._executor.submit(task.execute) + self._future_task_pairs.append((future, task)) # Track pairs + + def process_futures(self, timeout: Union[float, None] = 0) -> Tuple[List[_TaskResult], int]: + """ + Checks state of futures and collect results from completed ones. + + :param timeout: whether to wait for futures to complete or not: + - 0: don't wait, just return current state. + - non-zero value: wait for that many seconds to allow futures to complete, + - None: (not recommended) wait indefinitely. + + :returns: + Tuple of two elements: list of `_TaskResult` objects from completed tasks. + and number of remaining tasks that are still in progress. 
+ """ + results = [] + to_keep = [] + + done, _ = concurrent.futures.wait([f for f, _ in self._future_task_pairs], timeout=timeout) + + for future, task in self._future_task_pairs: + if future in done: + try: + result = future.result() + except Exception as e: + _log.exception(f"Threaded task {task!r} failed: {e!r}") + result = _TaskResult( + job_id=task.job_id, + df_idx = task.df_idx, + db_update={"status": "threaded task failed"}, + stats_update={"threaded task failed": 1}, + ) + results.append(result) + else: + to_keep.append((future, task)) + _log.info("process_futures: %d tasks done, %d tasks remaining", len(results), len(to_keep)) + + self._future_task_pairs = to_keep + return results, len(to_keep) + + def shutdown(self) -> None: + """Shuts down the thread pool gracefully.""" + _log.info("Shutting down thread pool") + self._executor.shutdown(wait=True) diff --git a/openeo/rest/_testing.py b/openeo/rest/_testing.py index a88eb1177..7c3196d86 100644 --- a/openeo/rest/_testing.py +++ b/openeo/rest/_testing.py @@ -34,6 +34,7 @@ class DummyBackend: # TODO: move to openeo.testing # TODO: unify "batch_jobs", "batch_jobs_full" and "extra_job_metadata_fields"? # TODO: unify "sync_requests" and "sync_requests_full"? + # TODO: support checking bearer token __slots__ = ( "_requests_mock", @@ -47,6 +48,7 @@ class DummyBackend: "next_result", "next_validation_errors", "_forced_job_status", + "_fail_on_job_start", "job_status_updater", "job_id_generator", "extra_job_metadata_fields", @@ -72,6 +74,7 @@ def __init__( self.next_validation_errors = [] self.extra_job_metadata_fields = [] self._forced_job_status: Dict[str, str] = {} + self._fail_on_job_start = {} # Job status update hook: # callable that is called on starting a job, and getting job metadata @@ -220,24 +223,52 @@ def _get_job_id(self, request) -> str: assert job_id in self.batch_jobs return job_id + def _set_job_status(self, job_id: str, status: str): + """Forced override of job status (e.g. for "canceled" or "error")""" + self.batch_jobs[job_id]["status"] = self._forced_job_status[job_id] = status + def _get_job_status(self, job_id: str, current_status: str) -> str: if job_id in self._forced_job_status: return self._forced_job_status[job_id] return self.job_status_updater(job_id=job_id, current_status=current_status) + def setup_job_start_failure( + self, + *, + job_id: Union[str, None] = None, + status_code: int = 500, + response_body: Union[None, str, dict] = None, + ): + """ + Setup for failure when starting a job. 
+ :param job_id: job id to fail on, or None (wildcard) for all jobs + """ + if response_body is None: + response_body = {"code": "Internal", "message": "No job starting for you, buddy"} + if not isinstance(response_body, bytes): + response_body = json.dumps(response_body).encode("utf-8") + self._fail_on_job_start[job_id] = {"status_code": status_code, "response_body": response_body} + def _handle_post_job_results(self, request, context): """Handler of `POST /job/{job_id}/results` (start batch job).""" job_id = self._get_job_id(request) assert self.batch_jobs[job_id]["status"] == "created" - self.batch_jobs[job_id]["status"] = self._get_job_status( - job_id=job_id, current_status=self.batch_jobs[job_id]["status"] - ) - context.status_code = 202 + + failure = self._fail_on_job_start.get(job_id) or self._fail_on_job_start.get(None) + if not failure: + self.batch_jobs[job_id]["status"] = self._get_job_status( + job_id=job_id, current_status=self.batch_jobs[job_id]["status"] + ) + context.status_code = 202 + else: + self._set_job_status(job_id=job_id, status="error") + context.status_code = failure["status_code"] + return failure["response_body"] def _handle_get_job(self, request, context): """Handler of `GET /job/{job_id}` (get batch job status and metadata).""" job_id = self._get_job_id(request) - # Allow updating status with `job_status_setter` once job got past status "created" + # Allow updating status with `job_status_updater` once job got past status "created" if self.batch_jobs[job_id]["status"] != "created": self.batch_jobs[job_id]["status"] = self._get_job_status( job_id=job_id, current_status=self.batch_jobs[job_id]["status"] @@ -268,8 +299,7 @@ def _handle_get_job_results(self, request, context): def _handle_delete_job_results(self, request, context): """Handler of `DELETE /job/{job_id}/results` (cancel job).""" job_id = self._get_job_id(request) - self.batch_jobs[job_id]["status"] = "canceled" - self._forced_job_status[job_id] = "canceled" + self._set_job_status(job_id=job_id, status="canceled") context.status_code = 204 def _handle_get_job_result_asset(self, request, context): diff --git a/tests/extra/job_management/test_job_management.py b/tests/extra/job_management/test_job_management.py index 26944023b..55aa23ea6 100644 --- a/tests/extra/job_management/test_job_management.py +++ b/tests/extra/job_management/test_job_management.py @@ -1,12 +1,14 @@ +import collections import copy import datetime import json import logging import re import threading +import time from pathlib import Path from time import sleep -from typing import Callable, Union +from typing import Union from unittest import mock import dirty_equals @@ -38,6 +40,11 @@ create_job_db, get_job_db, ) +from openeo.extra.job_management._thread_worker import ( + Task, + _JobManagerWorkerThreadPool, + _TaskResult, +) from openeo.rest._testing import OPENEO_BACKEND, DummyBackend, build_capabilities from openeo.util import rfc3339 from openeo.utils.version import ComparableVersion @@ -81,6 +88,26 @@ def sleep_mock(): yield sleep +class DummyTask(Task): + """ + A Task that simply sleeps and then returns a predetermined _TaskResult. 
+ """ + + def __init__(self, job_id, df_idx, db_update, stats_update): + super().__init__(job_id=job_id, df_idx = df_idx) + self._db_update = db_update or {} + self._stats_update = stats_update or {} + + def execute(self) -> _TaskResult: + + return _TaskResult( + job_id=self.job_id, + df_idx = self.df_idx, + db_update=self._db_update, + stats_update=self._stats_update, + ) + + class TestMultiBackendJobManager: @pytest.fixture @@ -156,7 +183,6 @@ def test_basic(self, tmp_path, job_manager, job_manager_root_dir, sleep_mock): job_db_path = tmp_path / "jobs.csv" job_db = CsvJobDatabase(job_db_path).initialize_from_df(df) - run_stats = job_manager.run_jobs(job_db=job_db, start_job=self._create_year_job) assert run_stats == dirty_equals.IsPartialDict( { @@ -590,14 +616,13 @@ def get_status(job_id, current_status): time_machine.move_to(create_time) job_db_path = tmp_path / "jobs.csv" + # Mock sleep() to not actually sleep, but skip one hour at a time with mock.patch.object(openeo.extra.job_management.time, "sleep", new=lambda s: time_machine.shift(60 * 60)): job_manager.run_jobs(df=df, start_job=self._create_year_job, job_db=job_db_path) final_df = CsvJobDatabase(job_db_path).read() - assert final_df.iloc[0].to_dict() == dirty_equals.IsPartialDict( - id="job-2024", status=expected_status, running_start_time="2024-09-01T10:00:00Z" - ) + assert dirty_equals.IsPartialDict(id="job-2024", status=expected_status) == final_df.iloc[0].to_dict() assert dummy_backend_foo.batch_jobs == { "job-2024": { @@ -644,10 +669,11 @@ def test_status_logging(self, tmp_path, job_manager, job_manager_root_dir, sleep run_stats = job_manager.run_jobs(job_db=job_db, start_job=self._create_year_job) assert run_stats == dirty_equals.IsPartialDict({"start_job call": 5, "job finished": 5}) - needle = re.compile(r"Job status histogram:.*'queued': 4.*Run stats:.*'start_job call': 4") + needle = re.compile(r"Job status histogram:.*'finished': 5.*Run stats:.*'job_queued_for_start': 5") assert needle.search(caplog.text) + @pytest.mark.parametrize( ["create_time", "start_time", "running_start_time", "end_time", "end_status", "cancel_after_seconds"], [ @@ -713,7 +739,7 @@ def get_status(job_id, current_status): # Mock sleep() to skip one hour at a time instead of actually sleeping with mock.patch.object(openeo.extra.job_management.time, "sleep", new=lambda s: time_machine.shift(60 * 60)): job_manager.run_jobs(df=df, start_job=self._create_year_job, job_db=job_db_path) - + final_df = CsvJobDatabase(job_db_path).read() # Validate running_start_time is a valid datetime object @@ -721,6 +747,95 @@ def get_status(job_id, current_status): assert isinstance(rfc3339.parse_datetime(filled_running_start_time), datetime.datetime) + def test_process_threadworker_updates(self, tmp_path, caplog): + pool = _JobManagerWorkerThreadPool(max_workers=2) + stats = collections.defaultdict(int) + + # Submit tasks covering all cases + pool.submit_task(DummyTask("j-0", df_idx=0, db_update={"status": "queued"}, stats_update={"queued": 1})) + pool.submit_task(DummyTask("j-1", df_idx=1, db_update={"status": "queued"}, stats_update=None)) + pool.submit_task(DummyTask("j-2", df_idx=2, db_update=None, stats_update={"queued": 1})) + pool.submit_task(DummyTask("j-3", df_idx=3, db_update=None, stats_update=None)) + # Invalid index (not in DB) + pool.submit_task(DummyTask("j-missing", df_idx=4, db_update={"status": "created"}, stats_update=None)) + + df_initial = pd.DataFrame({ + "id": ["j-0", "j-1", "j-2", "j-3"], + "status": ["created", "created", "created", 
"created"], + }) + job_db = CsvJobDatabase(tmp_path / "jobs.csv").initialize_from_df(df_initial) + + mgr = MultiBackendJobManager(root_dir=tmp_path / "jobs") + + with caplog.at_level(logging.ERROR): + mgr._process_threadworker_updates(worker_pool=pool, job_db=job_db, stats=stats) + + df_final = job_db.read() + + # Assert no rows were appended + assert len(df_final) == 4 + + # Assert updates + assert df_final.loc[0, "status"] == "queued" + assert df_final.loc[1, "status"] == "queued" + assert df_final.loc[2, "status"] == "created" + assert df_final.loc[3, "status"] == "created" + + # Assert stats + assert stats.get("queued", 0) == 2 + assert stats["job_db persist"] == 1 + + # Assert error log for invalid index + assert any("Skipping non-existing dataframe indiches" in msg for msg in caplog.messages) + + def test_no_results_leaves_db_and_stats_untouched(self, tmp_path, caplog): + pool = _JobManagerWorkerThreadPool(max_workers=2) + stats = collections.defaultdict(int) + + df_initial = pd.DataFrame({"id": ["j-0"], "status": ["created"]}) + job_db = CsvJobDatabase(tmp_path / "jobs.csv").initialize_from_df(df_initial) + mgr = MultiBackendJobManager(root_dir=tmp_path / "jobs") + + mgr._process_threadworker_updates(pool, job_db, stats) + + df_final = job_db.read() + assert df_final.loc[0, "status"] == "created" + assert stats == {} + + + def test_logs_on_invalid_update(self, tmp_path, caplog): + pool = _JobManagerWorkerThreadPool(max_workers=2) + stats = collections.defaultdict(int) + + # Malformed db_update (not a dict unpackable via **) + class BadTask: + job_id = "bad-task" + df_idx = 0 + db_update = "invalid" # invalid + stats_update = "a" + + def execute(self): + return self + + pool.submit_task(BadTask()) + + df_initial = pd.DataFrame({"id": ["j-0"], "status": ["created"]}) + job_db = CsvJobDatabase(tmp_path / "jobs.csv").initialize_from_df(df_initial) + mgr = MultiBackendJobManager(root_dir=tmp_path / "jobs") + + with caplog.at_level(logging.ERROR): + mgr._process_threadworker_updates(pool, job_db, stats) + + # DB should remain unchanged + df_final = job_db.read() + assert df_final.loc[0, "status"] == "created" + + # Stats remain empty + assert stats == {} + + # Assert log about invalid db update + assert any("Skipping invalid db_update" in msg for msg in caplog.messages) + assert any("Skipping invalid stats_update" in msg for msg in caplog.messages) JOB_DB_DF_BASICS = pd.DataFrame( { diff --git a/tests/extra/job_management/test_thread_worker.py b/tests/extra/job_management/test_thread_worker.py new file mode 100644 index 000000000..1687cfc78 --- /dev/null +++ b/tests/extra/job_management/test_thread_worker.py @@ -0,0 +1,278 @@ +import logging +import threading +import time +from dataclasses import dataclass +from typing import Iterator + +import pytest +import requests + +from openeo.extra.job_management._thread_worker import ( + Task, + _JobManagerWorkerThreadPool, + _JobStartTask, + _TaskResult, +) +from openeo.rest._testing import DummyBackend + + +@pytest.fixture +def dummy_backend(requests_mock) -> DummyBackend: + dummy = DummyBackend.at_url("https://foo.test", requests_mock=requests_mock) + dummy.setup_simple_job_status_flow(queued=2, running=5) + return dummy + + +class TestTaskResult: + def test_default(self): + result = _TaskResult(job_id="j-123", df_idx = 0) + assert result.job_id == "j-123" + assert result.df_idx ==0 + assert result.db_update == {} + assert result.stats_update == {} + + +class TestJobStartTask: + def test_start_success(self, dummy_backend, caplog): + 
caplog.set_level(logging.WARNING) + job = dummy_backend.connection.create_job(process_graph={}) + + task = _JobStartTask(job_id=job.job_id, df_idx=0, root_url=dummy_backend.connection.root_url, bearer_token="h4ll0") + result = task.execute() + + assert result == _TaskResult( + job_id="job-000", + df_idx = 0, + db_update={"status": "queued"}, + stats_update={"job start": 1}, + ) + assert job.status() == "queued" + assert caplog.messages == [] + + def test_start_failure(self, dummy_backend, caplog): + caplog.set_level(logging.WARNING) + job = dummy_backend.connection.create_job(process_graph={}) + dummy_backend.setup_job_start_failure() + + task = _JobStartTask(job_id=job.job_id, df_idx=0, root_url=dummy_backend.connection.root_url, bearer_token="h4ll0") + result = task.execute() + + assert result == _TaskResult( + job_id="job-000", + df_idx=0, + db_update={"status": "start_failed"}, + stats_update={"start_job error": 1}, + ) + assert job.status() == "error" + assert caplog.messages == [ + "Failed to start job 'job-000': OpenEoApiError('[500] Internal: No job starting " "for you, buddy')" + ] + + + + +class NopTask(Task): + """Do Nothing""" + + def execute(self) -> _TaskResult: + return _TaskResult(job_id=self.job_id, df_idx=self.df_idx) + + +class DummyTask(Task): + def execute(self) -> _TaskResult: + if self.job_id == "j-666": + raise ValueError("Oh no!") + return _TaskResult( + job_id=self.job_id, + df_idx=self.df_idx, + db_update={"status": "dummified"}, + stats_update={"dummy": 1}, + ) + + +@dataclass(frozen=True) +class BlockingTask(Task): + """Another dummy task that blocks until an event is set, and optionally fails.""" + + event: threading.Event + timeout: int = 5 + success: bool = True + + def execute(self) -> _TaskResult: + released = self.event.wait(timeout=self.timeout) + if not released: + raise TimeoutError("Waiting for event timed out") + if not self.success: + raise ValueError("Oh no!") + return _TaskResult(job_id=self.job_id, df_idx=self.df_idx, db_update={"status": "all fine"}) + + + + +class TestJobManagerWorkerThreadPool: + @pytest.fixture + def worker_pool(self) -> Iterator[_JobManagerWorkerThreadPool]: + """Fixture for creating and cleaning up a worker thread pool.""" + pool = _JobManagerWorkerThreadPool(max_workers=2) + yield pool + pool.shutdown() + + def test_no_tasks(self, worker_pool): + results, remaining = worker_pool.process_futures(timeout=10) + assert results == [] + assert remaining == 0 + + def test_submit_and_process(self, worker_pool): + worker_pool.submit_task(DummyTask(job_id="j-123", df_idx=0)) + results, remaining = worker_pool.process_futures(timeout=10) + assert results == [ + _TaskResult(job_id="j-123", df_idx=0, db_update={"status": "dummified"}, stats_update={"dummy": 1}), + ] + assert remaining == 0 + + def test_submit_and_process_zero_timeout(self, worker_pool): + worker_pool.submit_task(DummyTask(job_id="j-123", df_idx=0)) + # Trigger context switch + time.sleep(0.1) + results, remaining = worker_pool.process_futures(timeout=0) + assert results == [ + _TaskResult(job_id="j-123", df_idx=0, db_update={"status": "dummified"}, stats_update={"dummy": 1}), + ] + assert remaining == 0 + + def test_submit_and_process_with_error(self, worker_pool): + worker_pool.submit_task(DummyTask(job_id="j-666", df_idx=0)) + results, remaining = worker_pool.process_futures(timeout=10) + assert results == [ + _TaskResult( + job_id="j-666", + df_idx = 0, + db_update={"status": "threaded task failed"}, + stats_update={"threaded task failed": 1}, + ), + ] + assert 
remaining == 0 + + def test_submit_and_process_iterative(self, worker_pool): + worker_pool.submit_task(NopTask(job_id="j-1", df_idx=1)) + results, remaining = worker_pool.process_futures(timeout=1) + assert results == [_TaskResult(job_id="j-1", df_idx=1)] + assert remaining == 0 + + # Add some more + worker_pool.submit_task(NopTask(job_id="j-22", df_idx=22)) + worker_pool.submit_task(NopTask(job_id="j-222", df_idx=222)) + results, remaining = worker_pool.process_futures(timeout=1) + assert results == [_TaskResult(job_id="j-22", df_idx=22), _TaskResult(job_id="j-222", df_idx=222)] + assert remaining == 0 + + def test_submit_multiple_simple(self, worker_pool): + # A bunch of dummy tasks + for j in range(5): + worker_pool.submit_task(NopTask(job_id=f"j-{j}", df_idx=j)) + + # Process all of them (non-zero timeout, which should be plenty of time for all of them to finish) + results, remaining = worker_pool.process_futures(timeout=1) + expected = [_TaskResult(job_id=f"j-{j}", df_idx=j) for j in range(5)] + assert sorted(results, key=lambda r: r.job_id) == expected + + def test_submit_multiple_blocking_and_failing(self, worker_pool): + # Setup bunch of blocking tasks, some failing + events = [] + n = 5 + for j in range(n): + event = threading.Event() + events.append(event) + worker_pool.submit_task( + BlockingTask( + job_id=f"j-{j}", + df_idx=j, + event=event, + success=j != 3, + ) + ) + + # Initial state: nothing happened yet + results, remaining = worker_pool.process_futures(timeout=0) + assert (results, remaining) == ([], n) + + # No changes even after timeout + results, remaining = worker_pool.process_futures(timeout=0.1) + assert (results, remaining) == ([], n) + + # Set one event and wait for corresponding result + events[0].set() + results, remaining = worker_pool.process_futures(timeout=0.1) + assert results == [ + _TaskResult(job_id="j-0", df_idx = 0, db_update={"status": "all fine"}), + ] + assert remaining == n - 1 + + # Release all but one event + for j in range(n - 1): + events[j].set() + results, remaining = worker_pool.process_futures(timeout=0.1) + assert results == [ + _TaskResult(job_id="j-1", df_idx = 1, db_update={"status": "all fine"}), + _TaskResult(job_id="j-2", df_idx = 2, db_update={"status": "all fine"}), + _TaskResult( + job_id="j-3", df_idx = 3, db_update={"status": "threaded task failed"}, stats_update={"threaded task failed": 1} + ), + ] + assert remaining == 1 + + # Release all events + for j in range(n): + events[j].set() + results, remaining = worker_pool.process_futures(timeout=0.1) + assert results == [ + _TaskResult(job_id="j-4", df_idx = 4, db_update={"status": "all fine"}), + ] + assert remaining == 0 + + def test_shutdown(self, worker_pool): + # Before shutdown + worker_pool.submit_task(NopTask(job_id="j-123", df_idx=0)) + results, remaining = worker_pool.process_futures(timeout=0.1) + assert (results, remaining) == ([_TaskResult(job_id="j-123", df_idx=0)], 0) + + worker_pool.shutdown() + + # After shutdown, no new tasks should be accepted + with pytest.raises(RuntimeError, match="cannot schedule new futures after shutdown"): + worker_pool.submit_task(NopTask(job_id="j-456", df_idx=1)) + + def test_job_start_task(self, worker_pool, dummy_backend, caplog): + caplog.set_level(logging.WARNING) + job = dummy_backend.connection.create_job(process_graph={}) + task = _JobStartTask(job_id=job.job_id, df_idx=0, root_url=dummy_backend.connection.root_url, bearer_token=None) + worker_pool.submit_task(task) + + results, remaining = worker_pool.process_futures(timeout=1) 
+ assert results == [ + _TaskResult( + job_id="job-000", + df_idx = 0, + db_update={"status": "queued"}, + stats_update={"job start": 1}, + ) + ] + assert remaining == 0 + assert caplog.messages == [] + + def test_job_start_task_failure(self, worker_pool, dummy_backend, caplog): + caplog.set_level(logging.WARNING) + dummy_backend.setup_job_start_failure() + + job = dummy_backend.connection.create_job(process_graph={}) + task = _JobStartTask(job_id=job.job_id, df_idx=0, root_url=dummy_backend.connection.root_url, bearer_token=None) + worker_pool.submit_task(task) + + results, remaining = worker_pool.process_futures(timeout=1) + assert results == [ + _TaskResult(job_id="job-000", df_idx=0, db_update={"status": "start_failed"}, stats_update={"start_job error": 1}) + ] + assert remaining == 0 + assert caplog.messages == [ + "Failed to start job 'job-000': OpenEoApiError('[500] Internal: No job starting for you, buddy')" + ] diff --git a/tests/rest/test_testing.py b/tests/rest/test_testing.py index 0c28fb391..589dda3dc 100644 --- a/tests/rest/test_testing.py +++ b/tests/rest/test_testing.py @@ -1,5 +1,8 @@ +import re + import pytest +from openeo.rest import OpenEoApiError from openeo.rest._testing import DummyBackend @@ -94,3 +97,10 @@ def test_setup_simple_job_status_flow_final_per_job(self, dummy_backend, con120) assert job0.status() == "finished" assert job1.status() == "error" assert job2.status() == "finished" + + def test_setup_job_start_failure(self, dummy_backend): + job = dummy_backend.connection.create_job(process_graph={}) + dummy_backend.setup_job_start_failure() + with pytest.raises(OpenEoApiError, match=re.escape("[500] Internal: No job starting for you, buddy")): + job.start() + assert job.status() == "error"
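
The snippet below is a minimal usage sketch of the worker-pool API that this patch introduces (`Task`, `_TaskResult`, `_JobManagerWorkerThreadPool`), modelled on how the tests above exercise it; the `EchoTask` class is a hypothetical stand-in for a real task such as `_JobStartTask`, which additionally needs a backend `root_url` and an optional bearer token. It is illustration only, not part of the patch.

    from dataclasses import dataclass

    from openeo.extra.job_management._thread_worker import (
        Task,
        _JobManagerWorkerThreadPool,
        _TaskResult,
    )


    @dataclass(frozen=True)
    class EchoTask(Task):
        """Hypothetical task: immediately reports a status update for its job."""

        def execute(self) -> _TaskResult:
            return _TaskResult(
                job_id=self.job_id,
                df_idx=self.df_idx,
                db_update={"status": "queued"},
                stats_update={"job start": 1},
            )


    pool = _JobManagerWorkerThreadPool(max_workers=2)
    pool.submit_task(EchoTask(job_id="j-123", df_idx=0))

    # Non-blocking poll: collect results of finished tasks, keep pending ones tracked.
    results, remaining = pool.process_futures(timeout=0)
    for result in results:
        print(result.job_id, result.db_update, result.stats_update)

    pool.shutdown()

As the `Task` docstring notes, task objects are kept as small frozen dataclasses (just ids, URLs and tokens) because they are executed in a thread pool; this is also why `_JobStartTask` opens its own connection in `execute()` rather than sharing the manager's `Connection` across threads.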