From 13bef97a10ab0d1ff14f418eb881de4f397d6874 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 11 Oct 2024 09:50:13 +0200 Subject: [PATCH 1/3] Remove idle and saturated sets from scheduler --- distributed/scheduler.py | 120 ++-------------------- distributed/stealing.py | 151 +++++++++++----------------- distributed/tests/test_scheduler.py | 10 +- distributed/tests/test_steal.py | 51 +++++----- 4 files changed, 93 insertions(+), 239 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index f1e2c1f3e3..46b89b26a6 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1770,11 +1770,8 @@ def __init__( self.clients["fire-and-forget"] = ClientState("fire-and-forget") self.extensions = {} self.host_info = host_info - self.idle = SortedDict() - self.idle_task_count = set() self.n_tasks = 0 self.resources = resources - self.saturated = set() self.tasks = tasks self.replicated_tasks = { ts for ts in self.tasks.values() if len(ts.who_has or ()) > 1 @@ -1865,7 +1862,6 @@ def __pdict__(self) -> dict[str, Any]: return { "bandwidth": self.bandwidth, "resources": self.resources, - "saturated": self.saturated, "unrunnable": self.unrunnable, "queued": self.queued, "n_tasks": self.n_tasks, @@ -1879,7 +1875,6 @@ def __pdict__(self) -> dict[str, Any]: "extensions": self.extensions, "clients": self.clients, "workers": self.workers, - "idle": self.idle, "host_info": self.host_info, } @@ -2310,7 +2305,7 @@ def decide_worker_rootish_queuing_disabled( # See root-ish-ness note below in `decide_worker_rootish_queuing_enabled` assert math.isinf(self.WORKER_SATURATION) or not ts._queueable - pool = self.idle.values() if self.idle else self.running + pool = self.running if not pool: return None @@ -2375,22 +2370,16 @@ def decide_worker_rootish_queuing_enabled(self) -> WorkerState | None: # then add that assertion here (and actually pass in the task). assert not math.isinf(self.WORKER_SATURATION) - if not self.idle_task_count: - # All workers busy? Task gets/stays queued. + if not self.running: return None # Just pick the least busy worker. # NOTE: this will lead to worst-case scheduling with regards to co-assignment. - ws = min( - self.idle_task_count, - key=lambda ws: len(ws.processing) / ws.nthreads, - ) + ws = min(self.running, key=lambda ws: len(ws.processing) / ws.nthreads) + if _worker_full(ws, self.WORKER_SATURATION): + return None if self.validate: assert self.workers.get(ws.address) is ws - assert not _worker_full(ws, self.WORKER_SATURATION), ( - ws, - _task_slots_available(ws, self.WORKER_SATURATION), - ) assert ws in self.running, (ws, self.running) return ws @@ -2434,7 +2423,7 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None: # dependencies, but its group is also smaller than the cluster. 
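With `idle_task_count` gone, `decide_worker_rootish_queuing_enabled` above now picks the least-busy running worker and then asks `_worker_full` directly, returning `None` if that worker has no free slot. For orientation, here is a minimal sketch of the approximate semantics of `_worker_full` / `_task_slots_available` (paraphrased against a stand-in worker class, not the real `WorkerState`; the exact rounding in `distributed` may differ):

```python
# Approximate semantics of the saturation helpers relied on above (sketch only;
# FakeWorker is a stand-in for WorkerState, not the real scheduler type).
import math
from dataclasses import dataclass, field


@dataclass
class FakeWorker:
    nthreads: int
    processing: set = field(default_factory=set)    # tasks assigned to this worker
    long_running: set = field(default_factory=set)  # tasks that seceded from the thread pool


def task_slots_available(ws: FakeWorker, saturation_factor: float) -> int:
    # Free slots before the worker counts as full under worker-saturation.
    assert not math.isinf(saturation_factor)
    return max(math.ceil(saturation_factor * ws.nthreads), 1) - (
        len(ws.processing) - len(ws.long_running)
    )


def worker_full(ws: FakeWorker, saturation_factor: float) -> bool:
    # With worker-saturation = inf (queuing disabled) a worker is never "full".
    if math.isinf(saturation_factor):
        return False
    return task_slots_available(ws, saturation_factor) <= 0


ws = FakeWorker(nthreads=2, processing={"x", "y", "z"})
assert worker_full(ws, 1.1)        # 3 tasks, ceil(1.1 * 2) = 3 slots -> full
assert not worker_full(ws, 2.0)    # 4 slots, 3 tasks -> one slot left
```

Once the idle set is removed, this predicate is the only guard between the queuing path and an already saturated worker, which is why the early `return None` is added above.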
# Fastpath when there are no related tasks or restrictions - worker_pool = self.idle or self.workers + worker_pool = self.workers # FIXME idle and workers are SortedDict's declared as dicts # because sortedcontainers is not annotated wp_vals = cast("Sequence[WorkerState]", worker_pool.values()) @@ -2927,7 +2916,6 @@ def _transition_waiting_queued(self, key: Key, stimulus_id: str) -> RecsMsgs: ts = self.tasks[key] if self.validate: - assert not self.idle_task_count, (ts, self.idle_task_count) self._validate_ready(ts) ts.state = "queued" @@ -3158,63 +3146,6 @@ def is_rootish(self, ts: TaskState) -> bool: and sum(map(len, tg.dependencies)) < self.rootish_tg_dependencies_threshold ) - def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0) -> None: - """Update the status of the idle and saturated state - - The scheduler keeps track of workers that are .. - - - Saturated: have enough work to stay busy - - Idle: do not have enough work to stay busy - - They are considered saturated if they both have enough tasks to occupy - all of their threads, and if the expected runtime of those tasks is - large enough. - - If ``distributed.scheduler.worker-saturation`` is not ``inf`` - (scheduler-side queuing is enabled), they are considered idle - if they have fewer tasks processing than the ``worker-saturation`` - threshold dictates. - - Otherwise, they are considered idle if they have fewer tasks processing - than threads, or if their tasks' total expected runtime is less than half - the expected runtime of the same number of average tasks. - - This is useful for load balancing and adaptivity. - """ - if self.total_nthreads == 0 or ws.status == Status.closed: - return - if occ < 0: - occ = ws.occupancy - - p = len(ws.processing) - - self.saturated.discard(ws) - if ws.status != Status.running: - self.idle.pop(ws.address, None) - elif self.is_unoccupied(ws, occ, p): - self.idle[ws.address] = ws - else: - self.idle.pop(ws.address, None) - nc = ws.nthreads - if p > nc: - pending = occ * (p - nc) / (p * nc) - if 0.4 < pending > 1.9 * (self.total_occupancy / self.total_nthreads): - self.saturated.add(ws) - - if not _worker_full(ws, self.WORKER_SATURATION) and ws.status == Status.running: - self.idle_task_count.add(ws) - else: - self.idle_task_count.discard(ws) - - def is_unoccupied( - self, ws: WorkerState, occupancy: float, nprocessing: int - ) -> bool: - nthreads = ws.nthreads - return ( - nprocessing < nthreads - or occupancy < nthreads * (self.total_occupancy / self.total_nthreads) / 2 - ) - def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> float: """ Get the estimated communication cost (in s.) to compute the task @@ -3402,7 +3333,6 @@ def _add_to_processing( ts.processing_on = ws ts.state = "processing" self.acquire_resources(ts, ws) - self.check_idle_saturated(ws) self.n_tasks += 1 if ts.actor: @@ -3468,7 +3398,6 @@ def _exit_processing_common(self, ts: TaskState) -> WorkerState | None: if self.workers.get(ws.address) is not ws: # may have been removed return None - self.check_idle_saturated(ws) self.release_resources(ts, ws) return ws @@ -4606,10 +4535,6 @@ async def add_worker( metrics=metrics, ) - # Do not need to adjust self.total_occupancy as self.occupancy[ws] cannot - # exist before this. - self.check_idle_saturated(ws) - self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop) awaitables = [] @@ -5227,13 +5152,11 @@ def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: so any tasks that became runnable are already in ``processing``. 
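For reference, the classification performed by the deleted `check_idle_saturated` / `is_unoccupied` pair, condensed into one pure function so it is clear what stealing and queuing can no longer read off `s.idle` / `s.saturated` (a sketch of the removed logic, not scheduler API; note that the original chained comparison `0.4 < pending > 1.9 * avg` parses as `pending > 0.4 and pending > 1.9 * avg`):

```python
# Condensed form of the removed idle/saturated classification. "occ" is the
# worker's occupancy in seconds, "avg" the cluster-wide occupancy per thread.
def classify(nprocessing: int, nthreads: int, occ: float, avg: float) -> str:
    if nprocessing < nthreads or occ < nthreads * avg / 2:
        return "idle"           # not enough work to keep every thread busy
    if nprocessing > nthreads:
        pending = occ * (nprocessing - nthreads) / (nprocessing * nthreads)
        if pending > 0.4 and pending > 1.9 * avg:
            return "saturated"  # far more work than it can chew through soon
    return "neither"


assert classify(nprocessing=1, nthreads=2, occ=10.0, avg=1.0) == "idle"
assert classify(nprocessing=8, nthreads=2, occ=40.0, avg=1.0) == "saturated"
assert classify(nprocessing=3, nthreads=2, occ=3.0, avg=1.0) == "neither"
```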
Otherwise, overproduction can occur if queued tasks get scheduled before downstream tasks. - Must be called after `check_idle_saturated`; i.e. `idle_task_count` must be up to date. """ if not self.queued: return slots_available = sum( - _task_slots_available(ws, self.WORKER_SATURATION) - for ws in self.idle_task_count + _task_slots_available(ws, self.WORKER_SATURATION) for ws in self.running ) if slots_available == 0: return @@ -5466,9 +5389,6 @@ async def remove_worker( self.rpc.remove(address) del self.stream_comms[address] del self.aliases[ws.name] - self.idle.pop(ws.address, None) - self.idle_task_count.discard(ws) - self.saturated.discard(ws) del self.workers[address] self._workers_removed_total += 1 ws.status = Status.closed @@ -5818,23 +5738,6 @@ def validate_state(self, allow_overlap: bool = False) -> None: if not (set(self.workers) == set(self.stream_comms)): raise ValueError("Workers not the same in all collections") - assert self.running.issuperset(self.idle.values()), ( - self.running.copy(), - set(self.idle.values()), - ) - assert self.running.issuperset(self.idle_task_count), ( - self.running.copy(), - self.idle_task_count.copy(), - ) - assert self.running.issuperset(self.saturated), ( - self.running.copy(), - self.saturated.copy(), - ) - assert self.saturated.isdisjoint(self.idle.values()), ( - self.saturated.copy(), - set(self.idle.values()), - ) - task_prefix_counts: defaultdict[str, int] = defaultdict(int) for w, ws in self.workers.items(): assert isinstance(w, str), (type(w), w) @@ -5845,14 +5748,10 @@ def validate_state(self, allow_overlap: bool = False) -> None: assert ws in self.running else: assert ws not in self.running - assert ws.address not in self.idle - assert ws not in self.saturated assert ws.long_running.issubset(ws.processing) if not ws.processing: assert not ws.occupancy - if ws.status == Status.running: - assert ws.address in self.idle assert not ws.needs_what.keys() & ws.has_what actual_needs_what: defaultdict[TaskState, int] = defaultdict(int) for ts in ws.processing: @@ -6136,7 +6035,6 @@ def handle_long_running( ts.prefix.duration_average = (old_duration + compute_duration) / 2 ws.add_to_long_running(ts) - self.check_idle_saturated(ws) self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) @@ -6164,16 +6062,12 @@ def handle_worker_status_change( if ws.status == Status.running: self.running.add(ws) - self.check_idle_saturated(ws) self.transitions( self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id ) self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) else: self.running.discard(ws) - self.idle.pop(ws.address, None) - self.idle_task_count.discard(ws) - self.saturated.discard(ws) self._refresh_no_workers_since() def handle_request_refresh_who_has( diff --git a/distributed/stealing.py b/distributed/stealing.py index e3c3ace81c..873ca4ec48 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -7,7 +7,7 @@ from functools import partial from math import log2 from time import time -from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict, cast from tlz import topk @@ -38,7 +38,6 @@ logger = logging.getLogger(__name__) - LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") _WORKER_STATE_CONFIRM = { @@ -71,7 +70,7 @@ class InFlightInfo(TypedDict): class WorkStealing(SchedulerPlugin): scheduler: Scheduler # {worker: ({ task states for level 0}, ..., {task states for level 14})} - stealable: dict[str, tuple[set[TaskState], ...]] + 
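One consequence of dropping `idle_task_count` shows up in `stimulus_queue_slots_maybe_opened` above: free slots are now summed over every running worker. A minimal sketch of that flow, with a plain heap standing in for the scheduler's priority queue (the helper and its names are illustrative, not scheduler API):

```python
# Sketch: release as many queued tasks as there are free slots on running
# workers. Workers are (nthreads, nprocessing) pairs; negative slot counts are
# clamped to 0 in this simplified version.
import heapq
import math


def maybe_open_queue_slots(running, queued, saturation=1.1):
    slots_available = sum(
        max(0, max(math.ceil(saturation * nthreads), 1) - nprocessing)
        for nthreads, nprocessing in running
    )
    released = []
    for _ in range(min(slots_available, len(queued))):
        released.append(heapq.heappop(queued))  # best (lowest) priority value first
    return released


queue = [(0, "task-a"), (1, "task-b"), (2, "task-c")]
heapq.heapify(queue)
# one worker with two free slots, one already full
assert maybe_open_queue_slots([(2, 1), (2, 3)], queue) == [(0, "task-a"), (1, "task-b")]
```

In the real method (the rest of which is not shown in the diff) the released keys become `"processing"` recommendations for `self.transitions`; only the slot accounting changes here.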
stealable: dict[str, tuple[dict[TaskState, Literal[None]], ...]] # { task state: (worker, level) } key_stealable: dict[TaskState, tuple[str, int]] # (multiplier for level 0, ... multiplier for level 14) @@ -163,7 +162,7 @@ def log(self, msg: Any) -> None: return self.scheduler.log_event("stealing", msg) def add_worker(self, scheduler: Any = None, worker: Any = None) -> None: - self.stealable[worker] = tuple(set() for _ in range(15)) + self.stealable[worker] = tuple(dict() for _ in range(15)) def remove_worker(self, scheduler: Scheduler, worker: str, **kwargs: Any) -> None: del self.stealable[worker] @@ -245,7 +244,7 @@ def put_key_in_stealable(self, ts: TaskState) -> None: assert ts.processing_on ws = ts.processing_on worker = ws.address - self.stealable[worker][level].add(ts) + self.stealable[worker][level][ts] = None self.key_stealable[ts] = (worker, level) if duration == ts.prefix.duration_average: @@ -262,7 +261,7 @@ def remove_key_from_stealable(self, ts: TaskState) -> None: return worker, level = result - self.stealable[worker][level].discard(ts) + self.stealable[worker][level].pop(ts, None) def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, None]: """The compute to communication time ratio of a key @@ -378,7 +377,7 @@ def move_task_confirm( assert ts.processing_on == victim try: - _log_msg = [key, state, victim.address, thief.address, stimulus_id] + _log_msg = [key, state, victim.name, thief.name, stimulus_id] if ( state in _WORKER_STATE_UNDEFINED @@ -416,9 +415,6 @@ def move_task_confirm( pdb.set_trace() raise - finally: - self.scheduler.check_idle_saturated(thief) - self.scheduler.check_idle_saturated(victim) @log_errors def balance(self) -> None: @@ -432,26 +428,35 @@ def balance(self) -> None: i = 0 # Paused and closing workers must never become thieves - potential_thieves = set(s.idle.values()) - if not potential_thieves or len(potential_thieves) == len(s.workers): + if len(s.running) < 2: + return + # FIXME: It would be better if this was constant time + potential_thieves = { + ws for ws in s.running if self._combined_nprocessing(ws) <= ws.nthreads + } + if not potential_thieves: return victim: WorkerState | None - potential_victims: set[WorkerState] | list[WorkerState] = s.saturated + potential_victims: set[WorkerState] | list[WorkerState] = ( + s.running - potential_thieves + ) + potential_victims = topk( + 10, + potential_victims, + key=lambda x: (self._combined_nprocessing(x), x.name), + ) + potential_victims = [ + ws + for ws in potential_victims + if self._combined_nprocessing(ws) > ws.nthreads + ] if not potential_victims: - potential_victims = topk(10, s.workers.values(), key=combined_occupancy) - potential_victims = [ - ws - for ws in potential_victims - if combined_occupancy(ws) > 0.2 - and self._combined_nprocessing(ws) > ws.nthreads - and ws not in potential_thieves - ] - if not potential_victims: - return - if len(potential_victims) < 20: - potential_victims = sorted( - potential_victims, key=combined_occupancy, reverse=True - ) + return + potential_victims = sorted( + potential_victims, + key=lambda x: (self._combined_nprocessing(x), x.name), + reverse=True, + ) assert potential_victims assert potential_thieves for level, _ in enumerate(self.cost_multipliers): @@ -462,7 +467,12 @@ def balance(self) -> None: if not stealable or not potential_thieves: continue - for ts in list(stealable): + # We're iterating in reverse order. 
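The container change at the top of this file is what the new iteration order further down relies on: a `dict` with `None` values behaves like a set that remembers insertion order. A tiny standalone illustration:

```python
# dict-as-ordered-set: O(1) add / discard / membership like a set, plus a
# stable insertion order that reversed() can walk from the back.
stealable: dict[str, None] = {}

for key in ("high-prio", "mid-prio", "low-prio"):   # inserted in priority order
    stealable[key] = None                           # set.add equivalent

stealable.pop("mid-prio", None)                     # set.discard equivalent
assert "high-prio" in stealable                     # membership test
assert list(reversed(stealable)) == ["low-prio", "high-prio"]
```

That is also why `put_key_in_stealable` switches from `.add(ts)` to `[ts] = None` and the discards become `.pop(ts, None)`.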
This way we're basically + # stealing the lowest priority tasks first which gives us a + # better chance for a successful steal. + # Note that this will disturb global ordering of tasks and + # therefore may contribute to memory pressure + for ts in reversed(list(stealable)): if not potential_thieves: break if ( @@ -471,14 +481,10 @@ def balance(self) -> None: or ts not in victim.processing ): # FIXME: Instead of discarding here, clean up stealable properly - stealable.discard(ts) + stealable.pop(ts, None) continue i += 1 - if not ( - thief := self._get_thief( - s, ts, potential_thieves, occupancies=occupancies - ) - ): + if not (thief := _get_thief(s, ts, potential_thieves)): continue occ_thief = combined_occupancy(thief) @@ -490,7 +496,7 @@ def balance(self) -> None: occ_thief + comm_cost_thief + compute <= occ_victim - (comm_cost_victim + compute) / 2 ): - self.move_task_request(ts, victim, thief) + stim_id = self.move_task_request(ts, victim, thief) cost = compute + comm_cost_victim log.append( ( @@ -498,32 +504,22 @@ def balance(self) -> None: level, ts.key, cost, - victim.address, + victim.name, occ_victim, - thief.address, + thief.name, occ_thief, + stim_id, ) ) self.metrics["request_count_total"][level] += 1 self.metrics["request_cost_total"][level] += cost - occ_thief = combined_occupancy(thief) - nproc_thief = self._combined_nprocessing(thief) - - # FIXME: In the worst case, the victim may have 3x the amount of work - # of the thief when this aborts balancing. - if not self.scheduler.is_unoccupied( - thief, occ_thief, nproc_thief - ): + if self._combined_nprocessing(thief) >= thief.nthreads: potential_thieves.discard(thief) # FIXME: move_task_request already implements some logic # for removing ts from stealable. If we made sure to # properly clean up, we would not need this - stealable.discard(ts) - self.scheduler.check_idle_saturated( - victim, occ=combined_occupancy(victim) - ) - + stealable.pop(ts, None) if log: self.log(("request", log)) self.count += 1 @@ -560,50 +556,19 @@ def story(self, *keys_or_ts: str | TaskState) -> list: out.append(t) return out - def stealing_objective( - self, ts: TaskState, ws: WorkerState, *, occupancies: dict[WorkerState, float] - ) -> tuple[float, ...]: - """Objective function to determine which worker should get the task - - Minimize expected start time. If a tie then break with data storage. - Notes - ----- - This method is a modified version of Scheduler.worker_objective that accounts - for in-flight requests. It must be kept in sync for work-stealing to work correctly. 
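The steal decision above is a single inequality: the move must leave the thief finishing sooner, by a margin, than the victim would after shedding the task. Pulled out as a pure function (same expression as in the diff, the wrapper is illustrative only):

```python
# Accept a steal only if the thief, after paying its communication cost, would
# still be better off than the victim minus half of what the victim saves.
def worth_stealing(
    occ_thief: float,         # seconds of work already on the thief
    occ_victim: float,        # seconds of work on the victim
    compute: float,           # expected runtime of the candidate task
    comm_cost_thief: float,   # cost of shipping dependencies to the thief
    comm_cost_victim: float,  # cost for the victim to fetch any missing dependencies
) -> bool:
    return (
        occ_thief + comm_cost_thief + compute
        <= occ_victim - (comm_cost_victim + compute) / 2
    )


assert worth_stealing(0.0, 10.0, 1.0, 0.5, 0.0)       # idle thief, loaded victim
assert not worth_stealing(4.0, 5.0, 1.0, 0.5, 0.0)    # nearly balanced, keep it
```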
- - See Also - -------- - Scheduler.worker_objective - """ - occupancy = self._combined_occupancy( - ws, - occupancies=occupancies, - ) / ws.nthreads + self.scheduler.get_comm_cost(ts, ws) - if ts.actor: - return (len(ws.actors), occupancy, ws.nbytes) - else: - return (occupancy, ws.nbytes) - - def _get_thief( - self, - scheduler: SchedulerState, - ts: TaskState, - potential_thieves: set[WorkerState], - *, - occupancies: dict[WorkerState, float], - ) -> WorkerState | None: - valid_workers = scheduler.valid_workers(ts) - if valid_workers is not None: - valid_thieves = potential_thieves & valid_workers - if valid_thieves: - potential_thieves = valid_thieves - elif not ts.loose_restrictions: - return None - return min( - potential_thieves, - key=partial(self.stealing_objective, ts, occupancies=occupancies), - ) +def _get_thief( + scheduler: SchedulerState, ts: TaskState, potential_thieves: set[WorkerState] +) -> WorkerState | None: + valid_workers = scheduler.valid_workers(ts) + if valid_workers is not None: + valid_thieves = potential_thieves & valid_workers + if valid_thieves: + potential_thieves = valid_thieves + elif not ts.loose_restrictions: + return None + objective = scheduler.worker_objective + return min(potential_thieves, key=lambda ws: (objective(ts, ws), ws.name)) fast_tasks = { diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 00cd4be3a7..5ee3cfdd26 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -511,8 +511,7 @@ async def test_queued_paused_new_worker(c, s, a, b): # wait for workers pausing to hit the scheduler await asyncio.sleep(0.01) - assert not s.idle - assert not s.idle_task_count + assert all(ws.processing for ws in s.workers.values()) assert not s.running async with Worker(s.address, nthreads=2) as w: @@ -560,8 +559,7 @@ async def test_queued_paused_unpaused(c, s, a, b, queue): await asyncio.sleep(0.01) assert not s.running - assert not s.idle - assert not s.idle_task_count + assert all(ws.processing for ws in s.workers.values()) # un-pause a.status = Status.running @@ -570,9 +568,7 @@ async def test_queued_paused_unpaused(c, s, a, b, queue): await asyncio.sleep(0.01) if queue: - assert not s.idle # workers should have been (or already were) filled - # If queuing is disabled, all workers might already be saturated when they un-pause. 
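With the extension-level objective dropped in this first commit, thief selection reduces to the module-level `_get_thief` above: worker restrictions narrow the candidates, then `worker_objective` with the worker name as tie-break picks deterministically. A sketch of that control flow with plain strings standing in for `WorkerState`:

```python
# Thief selection, mirroring the control flow of _get_thief above. `objective`
# maps a worker to its scheduling cost (lower is better).
def pick_thief(candidates, valid_workers, loose_restrictions, objective):
    if valid_workers is not None:
        valid = candidates & valid_workers
        if valid:
            candidates = valid
        elif not loose_restrictions:
            return None                     # nowhere legal to move the task
    return min(candidates, key=lambda w: (objective(w), w))  # name breaks ties


candidates = {"w1", "w2", "w3"}
cost = {"w1": 3.0, "w2": 1.0, "w3": 1.0}
assert pick_thief(candidates, None, False, cost.get) == "w2"      # tie -> lowest name
assert pick_thief(candidates, {"w1"}, False, cost.get) == "w1"    # restriction honoured
assert pick_thief(candidates, {"w9"}, False, cost.get) is None    # hard restriction fails
assert pick_thief(candidates, {"w9"}, True, cost.get) == "w2"     # loose: fall back
```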
- assert not s.idle_task_count + assert all(ws.processing for ws in s.workers.values()) await wait(final) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 591c0c8de8..597fc7f2d5 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -37,6 +37,7 @@ from distributed.utils import wait_for from distributed.utils_test import ( NO_AMM, + BlockedGatherDep, BlockedGetData, async_poll_for, captured_logger, @@ -750,12 +751,23 @@ def block(*args, event, **kwargs): # Balance several times since stealing might attempt to steal the already executing task # for each saturated worker and will need a chance to correct its mistake - for _ in workers: + previous = steal.count + print("# balance", steal.count) + steal.balance() + await steal.stop() + print(steal.count) + print([msg[1] for msg in s.get_events("stealing")]) + while steal.count != previous: + previous = steal.count + print("# balance", steal.count) steal.balance() # steal.stop() ensures that all in-flight stealing requests have been resolved await steal.stop() - + print(steal.count) + print([msg[1] for msg in s.get_events("stealing")]) await ev.set() + for w in workers: + w.block_gather_dep.set() await c.gather([f for fs in futures_per_worker.values() for f in fs]) result = [ @@ -800,15 +812,9 @@ def block(*args, event, **kwargs): [[0, 0], [0, 0], [0], [0]], id="no one clearly saturated", ), - # NOTE: There is a timing issue that workers may already start executing - # tasks before we call balance, i.e. the workers will reject the - # stealing request and we end up with a different end result. - # Particularly tests with many input tasks are more likely to fail since - # the test setup takes longer and allows the workers more time to - # schedule a task on the threadpool pytest.param( [[4, 2, 2, 2, 2, 1, 1], [4, 2, 1, 1], [], [], []], - [[4, 2, 2, 2], [4, 2, 1, 1], [2], [1], [1]], + [[4, 2, 2, 2, 2], [4, 2, 1], [1], [1], [1]], id="balance multiple saturated workers", ), ], @@ -820,9 +826,12 @@ async def test_balance_(*args, **kwargs): config = { "distributed.scheduler.default-task-durations": {str(i): 1 for i in range(10)} } - gen_cluster(client=True, nthreads=[("", 1)] * len(inp), config=config)( - test_balance_ - )() + gen_cluster( + client=True, + Worker=BlockedGatherDep, + nthreads=[("", 1)] * len(inp), + config=config, + )(test_balance_)() @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2, Worker=Nanny, timeout=60) @@ -1554,12 +1563,8 @@ def test_balance_multiple_to_replica(): def _correct_placement(actual): actual_task_counts = [len(placed) for placed in actual] # FIXME: A better task placement would be even but the current balancing - # logic aborts as soon as a worker is no longer classified as idle - # return actual_task_counts == [ - # 4, - # 4, - # 0, - # ] + # logic only steals when a worker has an idle thread which makes it a + # little less aggressive return actual_task_counts == [ 6, 2, @@ -1758,12 +1763,6 @@ async def _dependency_balance_test_permutation( workers, ) - # Re-evaluate idle/saturated classification to avoid outdated classifications due to - # the initialization order of workers. On a real cluster, this would get constantly - # updated by tasks being added or completing. 
- for ws in s.workers.values(): - s.check_idle_saturated(ws) - # Balance several since stealing might attempt to steal the already executing task # for each saturated worker and will need a chance to correct its mistake for _ in workers: @@ -1971,7 +1970,7 @@ def block(i: int, in_event: Event, block_event: Event) -> int: async with Worker(s.address, nthreads=1) as b: try: - await async_poll_for(lambda: s.idle, timeout=5) + await async_poll_for(lambda: any(not w.processing for w in s.workers.values()), timeout=5) wsA = s.workers[a.address] wsB = s.workers[b.address] ts = next(iter(wsA.processing)) @@ -2042,7 +2041,7 @@ def block(i: int, in_event: Event, block_event: Event) -> int: async with Worker(s.address, nthreads=1) as b: try: - await async_poll_for(lambda: s.idle, timeout=5) + await async_poll_for(lambda: any(not w.processing for w in s.workers.values()), timeout=5) wsB = s.workers[b.address] From 35ecd5ddd528d8d9f759d2f842bd33e902ee1d14 Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 13 May 2025 14:40:49 +0200 Subject: [PATCH 2/3] more fixes --- distributed/stealing.py | 78 ++++++++++++++++++++------------- distributed/tests/test_steal.py | 34 +++++++++----- 2 files changed, 71 insertions(+), 41 deletions(-) diff --git a/distributed/stealing.py b/distributed/stealing.py index 873ca4ec48..e303500c9f 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -131,7 +131,8 @@ async def start(self, scheduler: Any = None) -> None: if "stealing" in self.scheduler.periodic_callbacks: return pc = PeriodicCallback( - callback=self.balance, callback_time=self._callback_time * 1000 + callback=self.balance, # type: ignore + callback_time=self._callback_time * 1000, ) pc.start() self.scheduler.periodic_callbacks["stealing"] = pc @@ -377,7 +378,7 @@ def move_task_confirm( assert ts.processing_on == victim try: - _log_msg = [key, state, victim.name, thief.name, stimulus_id] + _log_msg = [key, state, victim.address, thief.address, stimulus_id] if ( state in _WORKER_STATE_UNDEFINED @@ -417,7 +418,7 @@ def move_task_confirm( raise @log_errors - def balance(self) -> None: + def balance(self) -> bool: s = self.scheduler log = [] start = time() @@ -429,13 +430,12 @@ def balance(self) -> None: i = 0 # Paused and closing workers must never become thieves if len(s.running) < 2: - return - # FIXME: It would be better if this was constant time + return False potential_thieves = { ws for ws in s.running if self._combined_nprocessing(ws) <= ws.nthreads } if not potential_thieves: - return + return False victim: WorkerState | None potential_victims: set[WorkerState] | list[WorkerState] = ( s.running - potential_thieves @@ -451,7 +451,7 @@ def balance(self) -> None: if self._combined_nprocessing(ws) > ws.nthreads ] if not potential_victims: - return + return False potential_victims = sorted( potential_victims, key=lambda x: (self._combined_nprocessing(x), x.name), @@ -467,12 +467,7 @@ def balance(self) -> None: if not stealable or not potential_thieves: continue - # We're iterating in reverse order. This way we're basically - # stealing the lowest priority tasks first which gives us a - # better chance for a successful steal. 
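In this follow-up commit `balance()` rebuilds both sides of the trade each cycle: running workers with no more tasks than threads become potential thieves, and the ten most loaded of the rest become candidate victims. A condensed sketch with `(name, nthreads, nprocessing)` tuples in place of `WorkerState` (the real code also counts in-flight steals via `_combined_nprocessing`):

```python
# Partition running workers into potential thieves and the top-10 victims,
# mirroring the selection logic in balance() above.
from heapq import nlargest


def partition(running):
    thieves = [w for w in running if w[2] <= w[1]]      # not oversubscribed: may steal
    loaded = [w for w in running if w[2] > w[1]]        # more tasks than threads
    victims = nlargest(10, loaded, key=lambda w: (w[2], w[0]))
    return thieves, victims


running = [("a", 2, 1), ("b", 2, 2), ("c", 2, 9), ("d", 2, 5)]
thieves, victims = partition(running)
assert [w[0] for w in thieves] == ["a", "b"]
assert [w[0] for w in victims] == ["c", "d"]            # most loaded first
```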
- # Note that this will disturb global ordering of tasks and - # therefore may contribute to memory pressure - for ts in reversed(list(stealable)): + for ts in list(stealable): if not potential_thieves: break if ( @@ -484,7 +479,11 @@ def balance(self) -> None: stealable.pop(ts, None) continue i += 1 - if not (thief := _get_thief(s, ts, potential_thieves)): + if not ( + thief := self._get_thief( + s, ts, potential_thieves, occupancies=occupancies + ) + ): continue occ_thief = combined_occupancy(thief) @@ -504,9 +503,9 @@ def balance(self) -> None: level, ts.key, cost, - victim.name, + victim.address, occ_victim, - thief.name, + thief.address, occ_thief, stim_id, ) @@ -514,8 +513,8 @@ def balance(self) -> None: self.metrics["request_count_total"][level] += 1 self.metrics["request_cost_total"][level] += cost - if self._combined_nprocessing(thief) >= thief.nthreads: - potential_thieves.discard(thief) + # if self._combined_nprocessing(thief) >= thief.nthreads: + # potential_thieves.discard(thief) # FIXME: move_task_request already implements some logic # for removing ts from stealable. If we made sure to # properly clean up, we would not need this @@ -526,6 +525,7 @@ def balance(self) -> None: stop = time() if s.digests: s.digests["steal-duration"].add(stop - start) + return bool(log) def _combined_occupancy( self, ws: WorkerState, *, occupancies: dict[WorkerState, float] @@ -556,19 +556,37 @@ def story(self, *keys_or_ts: str | TaskState) -> list: out.append(t) return out + def stealing_objective( + self, ts: TaskState, ws: WorkerState, *, occupancies: dict[WorkerState, float] + ) -> tuple[float, ...]: + + occupancy = self._combined_occupancy( + ws, occupancies=occupancies + ) / ws.nthreads + self.scheduler.get_comm_cost(ts, ws) + if ts.actor: + return (len(ws.actors), occupancy, ws.nbytes) + else: + return (occupancy, ws.nbytes) -def _get_thief( - scheduler: SchedulerState, ts: TaskState, potential_thieves: set[WorkerState] -) -> WorkerState | None: - valid_workers = scheduler.valid_workers(ts) - if valid_workers is not None: - valid_thieves = potential_thieves & valid_workers - if valid_thieves: - potential_thieves = valid_thieves - elif not ts.loose_restrictions: - return None - objective = scheduler.worker_objective - return min(potential_thieves, key=lambda ws: (objective(ts, ws), ws.name)) + def _get_thief( + self, + scheduler: SchedulerState, + ts: TaskState, + potential_thieves: set[WorkerState], + *, + occupancies: dict[WorkerState, float], + ) -> WorkerState | None: + valid_workers = scheduler.valid_workers(ts) + if valid_workers is not None: + valid_thieves = potential_thieves & valid_workers + if valid_thieves: + potential_thieves = valid_thieves + elif not ts.loose_restrictions: + return None + return min( + potential_thieves, + key=partial(self.stealing_objective, ts, occupancies=occupancies), + ) fast_tasks = { diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 597fc7f2d5..8bed5ab365 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -538,8 +538,8 @@ async def test_steal_resource_restrictions(c, s, a): while s.workers[b.address].status != Status.running: await asyncio.sleep(0.01) - s.extensions["stealing"].balance() - await asyncio.sleep(0.1) + while s.extensions["stealing"].balance(): + await asyncio.sleep(0.1) assert 20 < len(b.state.tasks) < 80 assert 20 < len(a.state.tasks) < 80 @@ -1557,17 +1557,25 @@ def _correct_placement(actual): def test_balance_multiple_to_replica(): dependencies = {"a": 6} - 
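Because `balance()` now returns whether it actually requested any moves, the tests below can drive it to a fixed point instead of guessing how many passes are needed. The same idiom is useful interactively; a sketch, assuming `extension` is the scheduler's `WorkStealing` plugin instance:

```python
# Keep rebalancing until a pass steals nothing, waiting between passes so that
# in-flight move_task_request round-trips can settle.
import asyncio


async def rebalance_until_stable(extension, interval=0.1, max_rounds=50):
    for _ in range(max_rounds):
        if not extension.balance():   # False: no steal requests were issued
            return
        await asyncio.sleep(interval)
    raise RuntimeError("work stealing did not converge")
```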
dependency_placement = [["a"], ["a"], []] - task_placement = [[["a"], ["a"], ["a"], ["a"], ["a"], ["a"], ["a"], ["a"]], [], []] + dependency_placement = [ + ["a"], + ["a"], + [], + ] + task_placement = [ + [["a"], ["a"], ["a"], ["a"], ["a"], ["a"], ["a"], ["a"]], + [], + [], + ] def _correct_placement(actual): actual_task_counts = [len(placed) for placed in actual] - # FIXME: A better task placement would be even but the current balancing - # logic only steals when a worker has an idle thread which makes it a - # little less aggressive + # FIXME: A more homogeneous task placement would be even but the current + # balancing logic only steals when a worker has an idle thread which + # makes it a little less aggressive return actual_task_counts == [ - 6, - 2, + 5, + 3, 0, ] @@ -1970,7 +1978,9 @@ def block(i: int, in_event: Event, block_event: Event) -> int: async with Worker(s.address, nthreads=1) as b: try: - await async_poll_for(lambda: any(not w.processing for w in s.workers.values()), timeout=5) + await async_poll_for( + lambda: any(not w.processing for w in s.workers.values()), timeout=5 + ) wsA = s.workers[a.address] wsB = s.workers[b.address] ts = next(iter(wsA.processing)) @@ -2041,7 +2051,9 @@ def block(i: int, in_event: Event, block_event: Event) -> int: async with Worker(s.address, nthreads=1) as b: try: - await async_poll_for(lambda: any(not w.processing for w in s.workers.values()), timeout=5) + await async_poll_for( + lambda: any(not w.processing for w in s.workers.values()), timeout=5 + ) wsB = s.workers[b.address] From a66f139ce8317df6478e83ab52f4826a334a8588 Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 13 May 2025 14:42:43 +0200 Subject: [PATCH 3/3] kill dead tests --- distributed/tests/test_worker_memory.py | 39 ------------------------- 1 file changed, 39 deletions(-) diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py index 4a222861d8..00ba91be20 100644 --- a/distributed/tests/test_worker_memory.py +++ b/distributed/tests/test_worker_memory.py @@ -1090,45 +1090,6 @@ async def test_deprecated_params(s, name): assert getattr(a.memory_manager, name) == 0.789 -@gen_cluster(config={"distributed.worker.memory.monitor-interval": "10ms"}) -async def test_pause_while_idle(s, a, b): - sa = s.workers[a.address] - assert a.address in s.idle - assert sa in s.running - - a.monitor.get_process_memory = lambda: 2**40 - await async_poll_for(lambda: sa.status == Status.paused, timeout=5) - assert a.address not in s.idle - assert sa not in s.running - - a.monitor.get_process_memory = lambda: 0 - await async_poll_for(lambda: sa.status == Status.running, timeout=5) - assert a.address in s.idle - assert sa in s.running - - -@gen_cluster(client=True, config={"distributed.worker.memory.monitor-interval": "10ms"}) -async def test_pause_while_saturated(c, s, a, b): - sa = s.workers[a.address] - ev = Event() - futs = c.map(lambda i, ev: ev.wait(), range(3), ev=ev, workers=[a.address]) - await async_poll_for(lambda: len(a.state.tasks) == 3, timeout=5) - assert sa in s.saturated - assert sa in s.running - - a.monitor.get_process_memory = lambda: 2**40 - await async_poll_for(lambda: sa.status == Status.paused, timeout=5) - assert sa not in s.saturated - assert sa not in s.running - - a.monitor.get_process_memory = lambda: 0 - await async_poll_for(lambda: sa.status == Status.running, timeout=5) - assert sa in s.saturated - assert sa in s.running - - await ev.set() - - @gen_cluster(nthreads=[]) async def test_worker_log_memory_limit_too_high(s): async with 
Worker(s.address, memory_limit="1 PB") as worker:
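The two memory tests deleted in this last commit asserted membership in `s.idle` and `s.saturated`, which no longer exist. The underlying behaviour is still observable through worker status and `s.running`; a sketch of equivalent checks, assuming `s` and `a` are the scheduler and a worker from a `gen_cluster` test:

```python
# Pause state remains visible without the idle/saturated sets.
from distributed.core import Status


def assert_paused(s, a):
    ws = s.workers[a.address]
    assert ws.status == Status.paused
    assert ws not in s.running        # paused workers must not receive new tasks


def assert_unpaused(s, a):
    ws = s.workers[a.address]
    assert ws.status == Status.running
    assert ws in s.running
```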