diff --git a/distributed/scheduler.py b/distributed/scheduler.py index f1e2c1f3e3..46b89b26a6 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1770,11 +1770,8 @@ def __init__( self.clients["fire-and-forget"] = ClientState("fire-and-forget") self.extensions = {} self.host_info = host_info - self.idle = SortedDict() - self.idle_task_count = set() self.n_tasks = 0 self.resources = resources - self.saturated = set() self.tasks = tasks self.replicated_tasks = { ts for ts in self.tasks.values() if len(ts.who_has or ()) > 1 @@ -1865,7 +1862,6 @@ def __pdict__(self) -> dict[str, Any]: return { "bandwidth": self.bandwidth, "resources": self.resources, - "saturated": self.saturated, "unrunnable": self.unrunnable, "queued": self.queued, "n_tasks": self.n_tasks, @@ -1879,7 +1875,6 @@ def __pdict__(self) -> dict[str, Any]: "extensions": self.extensions, "clients": self.clients, "workers": self.workers, - "idle": self.idle, "host_info": self.host_info, } @@ -2310,7 +2305,7 @@ def decide_worker_rootish_queuing_disabled( # See root-ish-ness note below in `decide_worker_rootish_queuing_enabled` assert math.isinf(self.WORKER_SATURATION) or not ts._queueable - pool = self.idle.values() if self.idle else self.running + pool = self.running if not pool: return None @@ -2375,22 +2370,16 @@ def decide_worker_rootish_queuing_enabled(self) -> WorkerState | None: # then add that assertion here (and actually pass in the task). assert not math.isinf(self.WORKER_SATURATION) - if not self.idle_task_count: - # All workers busy? Task gets/stays queued. + if not self.running: return None # Just pick the least busy worker. # NOTE: this will lead to worst-case scheduling with regards to co-assignment. - ws = min( - self.idle_task_count, - key=lambda ws: len(ws.processing) / ws.nthreads, - ) + ws = min(self.running, key=lambda ws: len(ws.processing) / ws.nthreads) + if _worker_full(ws, self.WORKER_SATURATION): + return None if self.validate: assert self.workers.get(ws.address) is ws - assert not _worker_full(ws, self.WORKER_SATURATION), ( - ws, - _task_slots_available(ws, self.WORKER_SATURATION), - ) assert ws in self.running, (ws, self.running) return ws @@ -2434,7 +2423,7 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None: # dependencies, but its group is also smaller than the cluster. # Fastpath when there are no related tasks or restrictions - worker_pool = self.idle or self.workers + worker_pool = self.workers # FIXME idle and workers are SortedDict's declared as dicts # because sortedcontainers is not annotated wp_vals = cast("Sequence[WorkerState]", worker_pool.values()) @@ -2927,7 +2916,6 @@ def _transition_waiting_queued(self, key: Key, stimulus_id: str) -> RecsMsgs: ts = self.tasks[key] if self.validate: - assert not self.idle_task_count, (ts, self.idle_task_count) self._validate_ready(ts) ts.state = "queued" @@ -3158,63 +3146,6 @@ def is_rootish(self, ts: TaskState) -> bool: and sum(map(len, tg.dependencies)) < self.rootish_tg_dependencies_threshold ) - def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0) -> None: - """Update the status of the idle and saturated state - - The scheduler keeps track of workers that are .. - - - Saturated: have enough work to stay busy - - Idle: do not have enough work to stay busy - - They are considered saturated if they both have enough tasks to occupy - all of their threads, and if the expected runtime of those tasks is - large enough. 
- - If ``distributed.scheduler.worker-saturation`` is not ``inf`` - (scheduler-side queuing is enabled), they are considered idle - if they have fewer tasks processing than the ``worker-saturation`` - threshold dictates. - - Otherwise, they are considered idle if they have fewer tasks processing - than threads, or if their tasks' total expected runtime is less than half - the expected runtime of the same number of average tasks. - - This is useful for load balancing and adaptivity. - """ - if self.total_nthreads == 0 or ws.status == Status.closed: - return - if occ < 0: - occ = ws.occupancy - - p = len(ws.processing) - - self.saturated.discard(ws) - if ws.status != Status.running: - self.idle.pop(ws.address, None) - elif self.is_unoccupied(ws, occ, p): - self.idle[ws.address] = ws - else: - self.idle.pop(ws.address, None) - nc = ws.nthreads - if p > nc: - pending = occ * (p - nc) / (p * nc) - if 0.4 < pending > 1.9 * (self.total_occupancy / self.total_nthreads): - self.saturated.add(ws) - - if not _worker_full(ws, self.WORKER_SATURATION) and ws.status == Status.running: - self.idle_task_count.add(ws) - else: - self.idle_task_count.discard(ws) - - def is_unoccupied( - self, ws: WorkerState, occupancy: float, nprocessing: int - ) -> bool: - nthreads = ws.nthreads - return ( - nprocessing < nthreads - or occupancy < nthreads * (self.total_occupancy / self.total_nthreads) / 2 - ) - def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> float: """ Get the estimated communication cost (in s.) to compute the task @@ -3402,7 +3333,6 @@ def _add_to_processing( ts.processing_on = ws ts.state = "processing" self.acquire_resources(ts, ws) - self.check_idle_saturated(ws) self.n_tasks += 1 if ts.actor: @@ -3468,7 +3398,6 @@ def _exit_processing_common(self, ts: TaskState) -> WorkerState | None: if self.workers.get(ws.address) is not ws: # may have been removed return None - self.check_idle_saturated(ws) self.release_resources(ts, ws) return ws @@ -4606,10 +4535,6 @@ async def add_worker( metrics=metrics, ) - # Do not need to adjust self.total_occupancy as self.occupancy[ws] cannot - # exist before this. - self.check_idle_saturated(ws) - self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop) awaitables = [] @@ -5227,13 +5152,11 @@ def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: so any tasks that became runnable are already in ``processing``. Otherwise, overproduction can occur if queued tasks get scheduled before downstream tasks. - Must be called after `check_idle_saturated`; i.e. `idle_task_count` must be up to date. 
""" if not self.queued: return slots_available = sum( - _task_slots_available(ws, self.WORKER_SATURATION) - for ws in self.idle_task_count + _task_slots_available(ws, self.WORKER_SATURATION) for ws in self.running ) if slots_available == 0: return @@ -5466,9 +5389,6 @@ async def remove_worker( self.rpc.remove(address) del self.stream_comms[address] del self.aliases[ws.name] - self.idle.pop(ws.address, None) - self.idle_task_count.discard(ws) - self.saturated.discard(ws) del self.workers[address] self._workers_removed_total += 1 ws.status = Status.closed @@ -5818,23 +5738,6 @@ def validate_state(self, allow_overlap: bool = False) -> None: if not (set(self.workers) == set(self.stream_comms)): raise ValueError("Workers not the same in all collections") - assert self.running.issuperset(self.idle.values()), ( - self.running.copy(), - set(self.idle.values()), - ) - assert self.running.issuperset(self.idle_task_count), ( - self.running.copy(), - self.idle_task_count.copy(), - ) - assert self.running.issuperset(self.saturated), ( - self.running.copy(), - self.saturated.copy(), - ) - assert self.saturated.isdisjoint(self.idle.values()), ( - self.saturated.copy(), - set(self.idle.values()), - ) - task_prefix_counts: defaultdict[str, int] = defaultdict(int) for w, ws in self.workers.items(): assert isinstance(w, str), (type(w), w) @@ -5845,14 +5748,10 @@ def validate_state(self, allow_overlap: bool = False) -> None: assert ws in self.running else: assert ws not in self.running - assert ws.address not in self.idle - assert ws not in self.saturated assert ws.long_running.issubset(ws.processing) if not ws.processing: assert not ws.occupancy - if ws.status == Status.running: - assert ws.address in self.idle assert not ws.needs_what.keys() & ws.has_what actual_needs_what: defaultdict[TaskState, int] = defaultdict(int) for ts in ws.processing: @@ -6136,7 +6035,6 @@ def handle_long_running( ts.prefix.duration_average = (old_duration + compute_duration) / 2 ws.add_to_long_running(ts) - self.check_idle_saturated(ws) self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) @@ -6164,16 +6062,12 @@ def handle_worker_status_change( if ws.status == Status.running: self.running.add(ws) - self.check_idle_saturated(ws) self.transitions( self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id ) self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) else: self.running.discard(ws) - self.idle.pop(ws.address, None) - self.idle_task_count.discard(ws) - self.saturated.discard(ws) self._refresh_no_workers_since() def handle_request_refresh_who_has( diff --git a/distributed/stealing.py b/distributed/stealing.py index e3c3ace81c..e303500c9f 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -7,7 +7,7 @@ from functools import partial from math import log2 from time import time -from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict, cast from tlz import topk @@ -38,7 +38,6 @@ logger = logging.getLogger(__name__) - LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") _WORKER_STATE_CONFIRM = { @@ -71,7 +70,7 @@ class InFlightInfo(TypedDict): class WorkStealing(SchedulerPlugin): scheduler: Scheduler # {worker: ({ task states for level 0}, ..., {task states for level 14})} - stealable: dict[str, tuple[set[TaskState], ...]] + stealable: dict[str, tuple[dict[TaskState, Literal[None]], ...]] # { task state: (worker, level) } key_stealable: dict[TaskState, tuple[str, int]] # (multiplier for level 0, ... 
multiplier for level 14) @@ -132,7 +131,8 @@ async def start(self, scheduler: Any = None) -> None: if "stealing" in self.scheduler.periodic_callbacks: return pc = PeriodicCallback( - callback=self.balance, callback_time=self._callback_time * 1000 + callback=self.balance, # type: ignore + callback_time=self._callback_time * 1000, ) pc.start() self.scheduler.periodic_callbacks["stealing"] = pc @@ -163,7 +163,7 @@ def log(self, msg: Any) -> None: return self.scheduler.log_event("stealing", msg) def add_worker(self, scheduler: Any = None, worker: Any = None) -> None: - self.stealable[worker] = tuple(set() for _ in range(15)) + self.stealable[worker] = tuple(dict() for _ in range(15)) def remove_worker(self, scheduler: Scheduler, worker: str, **kwargs: Any) -> None: del self.stealable[worker] @@ -245,7 +245,7 @@ def put_key_in_stealable(self, ts: TaskState) -> None: assert ts.processing_on ws = ts.processing_on worker = ws.address - self.stealable[worker][level].add(ts) + self.stealable[worker][level][ts] = None self.key_stealable[ts] = (worker, level) if duration == ts.prefix.duration_average: @@ -262,7 +262,7 @@ def remove_key_from_stealable(self, ts: TaskState) -> None: return worker, level = result - self.stealable[worker][level].discard(ts) + self.stealable[worker][level].pop(ts, None) def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, None]: """The compute to communication time ratio of a key @@ -416,12 +416,9 @@ def move_task_confirm( pdb.set_trace() raise - finally: - self.scheduler.check_idle_saturated(thief) - self.scheduler.check_idle_saturated(victim) @log_errors - def balance(self) -> None: + def balance(self) -> bool: s = self.scheduler log = [] start = time() @@ -432,26 +429,34 @@ def balance(self) -> None: i = 0 # Paused and closing workers must never become thieves - potential_thieves = set(s.idle.values()) - if not potential_thieves or len(potential_thieves) == len(s.workers): - return + if len(s.running) < 2: + return False + potential_thieves = { + ws for ws in s.running if self._combined_nprocessing(ws) <= ws.nthreads + } + if not potential_thieves: + return False victim: WorkerState | None - potential_victims: set[WorkerState] | list[WorkerState] = s.saturated + potential_victims: set[WorkerState] | list[WorkerState] = ( + s.running - potential_thieves + ) + potential_victims = topk( + 10, + potential_victims, + key=lambda x: (self._combined_nprocessing(x), x.name), + ) + potential_victims = [ + ws + for ws in potential_victims + if self._combined_nprocessing(ws) > ws.nthreads + ] if not potential_victims: - potential_victims = topk(10, s.workers.values(), key=combined_occupancy) - potential_victims = [ - ws - for ws in potential_victims - if combined_occupancy(ws) > 0.2 - and self._combined_nprocessing(ws) > ws.nthreads - and ws not in potential_thieves - ] - if not potential_victims: - return - if len(potential_victims) < 20: - potential_victims = sorted( - potential_victims, key=combined_occupancy, reverse=True - ) + return False + potential_victims = sorted( + potential_victims, + key=lambda x: (self._combined_nprocessing(x), x.name), + reverse=True, + ) assert potential_victims assert potential_thieves for level, _ in enumerate(self.cost_multipliers): @@ -471,7 +476,7 @@ def balance(self) -> None: or ts not in victim.processing ): # FIXME: Instead of discarding here, clean up stealable properly - stealable.discard(ts) + stealable.pop(ts, None) continue i += 1 if not ( @@ -490,7 +495,7 @@ def balance(self) -> None: occ_thief + comm_cost_thief 
+ compute <= occ_victim - (comm_cost_victim + compute) / 2 ): - self.move_task_request(ts, victim, thief) + stim_id = self.move_task_request(ts, victim, thief) cost = compute + comm_cost_victim log.append( ( @@ -502,34 +507,25 @@ def balance(self) -> None: occ_victim, thief.address, occ_thief, + stim_id, ) ) self.metrics["request_count_total"][level] += 1 self.metrics["request_cost_total"][level] += cost - occ_thief = combined_occupancy(thief) - nproc_thief = self._combined_nprocessing(thief) - - # FIXME: In the worst case, the victim may have 3x the amount of work - # of the thief when this aborts balancing. - if not self.scheduler.is_unoccupied( - thief, occ_thief, nproc_thief - ): - potential_thieves.discard(thief) + # if self._combined_nprocessing(thief) >= thief.nthreads: + # potential_thieves.discard(thief) # FIXME: move_task_request already implements some logic # for removing ts from stealable. If we made sure to # properly clean up, we would not need this - stealable.discard(ts) - self.scheduler.check_idle_saturated( - victim, occ=combined_occupancy(victim) - ) - + stealable.pop(ts, None) if log: self.log(("request", log)) self.count += 1 stop = time() if s.digests: s.digests["steal-duration"].add(stop - start) + return bool(log) def _combined_occupancy( self, ws: WorkerState, *, occupancies: dict[WorkerState, float] @@ -563,22 +559,9 @@ def story(self, *keys_or_ts: str | TaskState) -> list: def stealing_objective( self, ts: TaskState, ws: WorkerState, *, occupancies: dict[WorkerState, float] ) -> tuple[float, ...]: - """Objective function to determine which worker should get the task - - Minimize expected start time. If a tie then break with data storage. - Notes - ----- - This method is a modified version of Scheduler.worker_objective that accounts - for in-flight requests. It must be kept in sync for work-stealing to work correctly. - - See Also - -------- - Scheduler.worker_objective - """ occupancy = self._combined_occupancy( - ws, - occupancies=occupancies, + ws, occupancies=occupancies ) / ws.nthreads + self.scheduler.get_comm_cost(ts, ws) if ts.actor: return (len(ws.actors), occupancy, ws.nbytes) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 00cd4be3a7..5ee3cfdd26 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -511,8 +511,7 @@ async def test_queued_paused_new_worker(c, s, a, b): # wait for workers pausing to hit the scheduler await asyncio.sleep(0.01) - assert not s.idle - assert not s.idle_task_count + assert all(ws.processing for ws in s.workers.values()) assert not s.running async with Worker(s.address, nthreads=2) as w: @@ -560,8 +559,7 @@ async def test_queued_paused_unpaused(c, s, a, b, queue): await asyncio.sleep(0.01) assert not s.running - assert not s.idle - assert not s.idle_task_count + assert all(ws.processing for ws in s.workers.values()) # un-pause a.status = Status.running @@ -570,9 +568,7 @@ async def test_queued_paused_unpaused(c, s, a, b, queue): await asyncio.sleep(0.01) if queue: - assert not s.idle # workers should have been (or already were) filled - # If queuing is disabled, all workers might already be saturated when they un-pause. 
- assert not s.idle_task_count + assert all(ws.processing for ws in s.workers.values()) await wait(final) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 591c0c8de8..8bed5ab365 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -37,6 +37,7 @@ from distributed.utils import wait_for from distributed.utils_test import ( NO_AMM, + BlockedGatherDep, BlockedGetData, async_poll_for, captured_logger, @@ -537,8 +538,8 @@ async def test_steal_resource_restrictions(c, s, a): while s.workers[b.address].status != Status.running: await asyncio.sleep(0.01) - s.extensions["stealing"].balance() - await asyncio.sleep(0.1) + while s.extensions["stealing"].balance(): + await asyncio.sleep(0.1) assert 20 < len(b.state.tasks) < 80 assert 20 < len(a.state.tasks) < 80 @@ -750,12 +751,23 @@ def block(*args, event, **kwargs): # Balance several times since stealing might attempt to steal the already executing task # for each saturated worker and will need a chance to correct its mistake - for _ in workers: + previous = steal.count + print("# balance", steal.count) + steal.balance() + await steal.stop() + print(steal.count) + print([msg[1] for msg in s.get_events("stealing")]) + while steal.count != previous: + previous = steal.count + print("# balance", steal.count) steal.balance() # steal.stop() ensures that all in-flight stealing requests have been resolved await steal.stop() - + print(steal.count) + print([msg[1] for msg in s.get_events("stealing")]) await ev.set() + for w in workers: + w.block_gather_dep.set() await c.gather([f for fs in futures_per_worker.values() for f in fs]) result = [ @@ -800,15 +812,9 @@ def block(*args, event, **kwargs): [[0, 0], [0, 0], [0], [0]], id="no one clearly saturated", ), - # NOTE: There is a timing issue that workers may already start executing - # tasks before we call balance, i.e. the workers will reject the - # stealing request and we end up with a different end result. 
- # Particularly tests with many input tasks are more likely to fail since - # the test setup takes longer and allows the workers more time to - # schedule a task on the threadpool pytest.param( [[4, 2, 2, 2, 2, 1, 1], [4, 2, 1, 1], [], [], []], - [[4, 2, 2, 2], [4, 2, 1, 1], [2], [1], [1]], + [[4, 2, 2, 2, 2], [4, 2, 1], [1], [1], [1]], id="balance multiple saturated workers", ), ], @@ -820,9 +826,12 @@ async def test_balance_(*args, **kwargs): config = { "distributed.scheduler.default-task-durations": {str(i): 1 for i in range(10)} } - gen_cluster(client=True, nthreads=[("", 1)] * len(inp), config=config)( - test_balance_ - )() + gen_cluster( + client=True, + Worker=BlockedGatherDep, + nthreads=[("", 1)] * len(inp), + config=config, + )(test_balance_)() @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2, Worker=Nanny, timeout=60) @@ -1548,21 +1557,25 @@ def _correct_placement(actual): def test_balance_multiple_to_replica(): dependencies = {"a": 6} - dependency_placement = [["a"], ["a"], []] - task_placement = [[["a"], ["a"], ["a"], ["a"], ["a"], ["a"], ["a"], ["a"]], [], []] + dependency_placement = [ + ["a"], + ["a"], + [], + ] + task_placement = [ + [["a"], ["a"], ["a"], ["a"], ["a"], ["a"], ["a"], ["a"]], + [], + [], + ] def _correct_placement(actual): actual_task_counts = [len(placed) for placed in actual] - # FIXME: A better task placement would be even but the current balancing - # logic aborts as soon as a worker is no longer classified as idle - # return actual_task_counts == [ - # 4, - # 4, - # 0, - # ] + # FIXME: A more homogeneous task placement would be even but the current + # balancing logic only steals when a worker has an idle thread which + # makes it a little less aggressive return actual_task_counts == [ - 6, - 2, + 5, + 3, 0, ] @@ -1758,12 +1771,6 @@ async def _dependency_balance_test_permutation( workers, ) - # Re-evaluate idle/saturated classification to avoid outdated classifications due to - # the initialization order of workers. On a real cluster, this would get constantly - # updated by tasks being added or completing. 
- for ws in s.workers.values(): - s.check_idle_saturated(ws) - # Balance several since stealing might attempt to steal the already executing task # for each saturated worker and will need a chance to correct its mistake for _ in workers: @@ -1971,7 +1978,9 @@ def block(i: int, in_event: Event, block_event: Event) -> int: async with Worker(s.address, nthreads=1) as b: try: - await async_poll_for(lambda: s.idle, timeout=5) + await async_poll_for( + lambda: any(not w.processing for w in s.workers.values()), timeout=5 + ) wsA = s.workers[a.address] wsB = s.workers[b.address] ts = next(iter(wsA.processing)) @@ -2042,7 +2051,9 @@ def block(i: int, in_event: Event, block_event: Event) -> int: async with Worker(s.address, nthreads=1) as b: try: - await async_poll_for(lambda: s.idle, timeout=5) + await async_poll_for( + lambda: any(not w.processing for w in s.workers.values()), timeout=5 + ) wsB = s.workers[b.address] diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py index 4a222861d8..00ba91be20 100644 --- a/distributed/tests/test_worker_memory.py +++ b/distributed/tests/test_worker_memory.py @@ -1090,45 +1090,6 @@ async def test_deprecated_params(s, name): assert getattr(a.memory_manager, name) == 0.789 -@gen_cluster(config={"distributed.worker.memory.monitor-interval": "10ms"}) -async def test_pause_while_idle(s, a, b): - sa = s.workers[a.address] - assert a.address in s.idle - assert sa in s.running - - a.monitor.get_process_memory = lambda: 2**40 - await async_poll_for(lambda: sa.status == Status.paused, timeout=5) - assert a.address not in s.idle - assert sa not in s.running - - a.monitor.get_process_memory = lambda: 0 - await async_poll_for(lambda: sa.status == Status.running, timeout=5) - assert a.address in s.idle - assert sa in s.running - - -@gen_cluster(client=True, config={"distributed.worker.memory.monitor-interval": "10ms"}) -async def test_pause_while_saturated(c, s, a, b): - sa = s.workers[a.address] - ev = Event() - futs = c.map(lambda i, ev: ev.wait(), range(3), ev=ev, workers=[a.address]) - await async_poll_for(lambda: len(a.state.tasks) == 3, timeout=5) - assert sa in s.saturated - assert sa in s.running - - a.monitor.get_process_memory = lambda: 2**40 - await async_poll_for(lambda: sa.status == Status.paused, timeout=5) - assert sa not in s.saturated - assert sa not in s.running - - a.monitor.get_process_memory = lambda: 0 - await async_poll_for(lambda: sa.status == Status.running, timeout=5) - assert sa in s.saturated - assert sa in s.running - - await ev.set() - - @gen_cluster(nthreads=[]) async def test_worker_log_memory_limit_too_high(s): async with Worker(s.address, memory_limit="1 PB") as worker:
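
Reviewer note (appended after the diff, not part of it): the scheduler hunks replace the maintained `idle` / `idle_task_count` / `saturated` collections with on-demand checks against `self.running`, `len(ws.processing)`, `ws.nthreads`, and `_worker_full(ws, self.WORKER_SATURATION)`. The following is a minimal standalone sketch of that decision flow under simplified assumptions; `WorkerInfo`, `task_slots_available`, `worker_full`, and `pick_least_busy_running` are illustrative stand-ins rather than the real `distributed` API, and the slot formula only approximates `_task_slots_available`.

```python
# Minimal model of "pick the least busy running worker, or keep the task queued if
# even that worker is full", mirroring the new decide_worker_rootish_queuing_enabled flow.
from __future__ import annotations

import math
from dataclasses import dataclass, field


@dataclass(eq=False)
class WorkerInfo:
    """Illustrative stand-in for distributed's WorkerState."""

    address: str
    nthreads: int
    processing: set[str] = field(default_factory=set)  # keys currently assigned


def task_slots_available(ws: WorkerInfo, saturation: float) -> int:
    """Free task slots under the worker-saturation factor (approximation)."""
    if math.isinf(saturation):
        return 0
    return max(math.ceil(ws.nthreads * saturation), 1) - len(ws.processing)


def worker_full(ws: WorkerInfo, saturation: float) -> bool:
    return task_slots_available(ws, saturation) <= 0


def pick_least_busy_running(
    running: list[WorkerInfo], saturation: float
) -> WorkerInfo | None:
    """Return the least busy running worker, or None so the task stays queued."""
    if not running:
        return None
    ws = min(running, key=lambda ws: len(ws.processing) / ws.nthreads)
    if worker_full(ws, saturation):
        return None
    return ws


if __name__ == "__main__":
    a = WorkerInfo("tcp://a", nthreads=2, processing={"x", "y"})
    b = WorkerInfo("tcp://b", nthreads=2, processing={"z"})
    assert pick_least_busy_running([a, b], saturation=1.1) is b
    assert pick_least_busy_running([a], saturation=1.0) is None  # full: keep queued
```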
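
A second note on the stealing.py change from `set[TaskState]` buckets to `dict[TaskState, Literal[None]]`: a plain dict provides the same O(1) add/discard/membership as a set while also guaranteeing insertion-order iteration, which appears to be the point of the switch (the diff does not state the motivation explicitly). A small illustration of the dict-as-ordered-set pattern, with strings standing in for `TaskState` objects:

```python
# dict-as-ordered-set, mirroring stealable[worker][level][ts] = None / .pop(ts, None).
bucket: dict[str, None] = {}

bucket["task-a"] = None        # set.add equivalent
bucket["task-b"] = None
bucket["task-a"] = None        # re-adding is a no-op; no duplicates

bucket.pop("task-c", None)     # set.discard equivalent: no KeyError if missing

assert "task-a" in bucket                      # O(1) membership
assert list(bucket) == ["task-a", "task-b"]    # deterministic insertion order
```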