Skip to content

Commit 5819d9a

Browse files
committed
Fix flaky test_broken_worker
1 parent dfc4358 commit 5819d9a

File tree

3 files changed

+21
-19
lines changed

3 files changed

+21
-19
lines changed

.github/workflows/tests.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,8 @@ jobs:
203203
set -o pipefail
204204
mkdir reports
205205
206-
pytest distributed \
207-
-m "not avoid_ci and ${{ matrix.partition }}" --runslow \
206+
pytest distributed/deploy/tests/test_spec_cluster.py \
207+
--count 10 --runslow \
208208
--leaks=fds,processes,threads \
209209
--junitxml reports/pytest.xml -o junit_suite_name=$TEST_ID \
210210
--cov=distributed --cov-report=xml \

distributed/deploy/spec.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -379,13 +379,17 @@ async def _correct_state_internal(self) -> None:
379379
self._created.add(worker)
380380
workers.append(worker)
381381
if workers:
382-
await asyncio.wait(
383-
[asyncio.create_task(_wrap_awaitable(w)) for w in workers]
384-
)
382+
worker_futs = [asyncio.ensure_future(w) for w in workers]
383+
await asyncio.wait(worker_futs)
384+
self.workers.update(dict(zip(to_open, workers)))
385385
for w in workers:
386386
w._cluster = weakref.ref(self)
387+
# Collect exceptions from failed workers. This must happen after all
388+
# *other* workers have finished initialising, so that we can have a
389+
# proper teardown.
390+
await asyncio.gather(*worker_futs)
391+
for w in workers:
387392
await w # for tornado gen.coroutine support
388-
self.workers.update(dict(zip(to_open, workers)))
389393

390394
def _update_worker_status(self, op, msg):
391395
if op == "remove":
@@ -467,10 +471,14 @@ async def _close(self):
467471
await super()._close()
468472

469473
async def __aenter__(self):
470-
await self
471-
await self._correct_state()
472-
assert self.status == Status.running
473-
return self
474+
try:
475+
await self
476+
await self._correct_state()
477+
assert self.status == Status.running
478+
return self
479+
except Exception:
480+
await self.close()
481+
raise
474482

475483
def _threads_per_worker(self) -> int:
476484
"""Return the number of threads per worker for new workers"""

distributed/deploy/tests/test_spec_cluster.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,6 @@ async def test_restart():
207207
await asyncio.sleep(0.01)
208208

209209

210-
@pytest.mark.skipif(WINDOWS, reason="HTTP Server doesn't close out")
211210
@gen_test()
212211
async def test_broken_worker():
213212
class BrokenWorkerException(Exception):
@@ -216,7 +215,6 @@ class BrokenWorkerException(Exception):
216215
class BrokenWorker(Worker):
217216
def __await__(self):
218217
async def _():
219-
self.status = Status.closed
220218
raise BrokenWorkerException("Worker Broken")
221219

222220
return _().__await__()
@@ -226,13 +224,9 @@ async def _():
226224
workers={"good": {"cls": Worker}, "bad": {"cls": BrokenWorker}},
227225
scheduler=scheduler,
228226
)
229-
try:
230-
with pytest.raises(BrokenWorkerException, match=r"Worker Broken"):
231-
async with cluster:
232-
pass
233-
finally:
234-
# FIXME: SpecCluster leaks if SpecCluster.__aenter__ raises
235-
await cluster.close()
227+
with pytest.raises(BrokenWorkerException, match=r"Worker Broken"):
228+
async with cluster:
229+
pass
236230

237231

238232
@pytest.mark.skipif(WINDOWS, reason="HTTP Server doesn't close out")

0 commit comments

Comments
 (0)