File tree 3 files changed +21
-19
lines changed 3 files changed +21
-19
lines changed Original file line number Diff line number Diff line change @@ -203,8 +203,8 @@ jobs:
203
203
set -o pipefail
204
204
mkdir reports
205
205
206
- pytest distributed \
207
- -m "not avoid_ci and ${{ matrix.partition }}" --runslow \
206
+ pytest distributed/deploy/tests/test_spec_cluster.py \
207
+ --count 10 --runslow \
208
208
--leaks=fds,processes,threads \
209
209
--junitxml reports/pytest.xml -o junit_suite_name=$TEST_ID \
210
210
--cov=distributed --cov-report=xml \
Original file line number Diff line number Diff line change @@ -379,13 +379,17 @@ async def _correct_state_internal(self) -> None:
379
379
self ._created .add (worker )
380
380
workers .append (worker )
381
381
if workers :
382
- await asyncio .wait (
383
- [ asyncio .create_task ( _wrap_awaitable ( w )) for w in workers ]
384
- )
382
+ worker_futs = [ asyncio .ensure_future ( w ) for w in workers ]
383
+ await asyncio .wait ( worker_futs )
384
+ self . workers . update ( dict ( zip ( to_open , workers )) )
385
385
for w in workers :
386
386
w ._cluster = weakref .ref (self )
387
+ # Collect exceptions from failed workers. This must happen after all
388
+ # *other* workers have finished initialising, so that we can have a
389
+ # proper teardown.
390
+ await asyncio .gather (* worker_futs )
391
+ for w in workers :
387
392
await w # for tornado gen.coroutine support
388
- self .workers .update (dict (zip (to_open , workers )))
389
393
390
394
def _update_worker_status (self , op , msg ):
391
395
if op == "remove" :
@@ -467,10 +471,14 @@ async def _close(self):
467
471
await super ()._close ()
468
472
469
473
async def __aenter__ (self ):
470
- await self
471
- await self ._correct_state ()
472
- assert self .status == Status .running
473
- return self
474
+ try :
475
+ await self
476
+ await self ._correct_state ()
477
+ assert self .status == Status .running
478
+ return self
479
+ except Exception :
480
+ await self .close ()
481
+ raise
474
482
475
483
def _threads_per_worker (self ) -> int :
476
484
"""Return the number of threads per worker for new workers"""
Original file line number Diff line number Diff line change @@ -207,7 +207,6 @@ async def test_restart():
207
207
await asyncio .sleep (0.01 )
208
208
209
209
210
- @pytest .mark .skipif (WINDOWS , reason = "HTTP Server doesn't close out" )
211
210
@gen_test ()
212
211
async def test_broken_worker ():
213
212
class BrokenWorkerException (Exception ):
@@ -216,7 +215,6 @@ class BrokenWorkerException(Exception):
216
215
class BrokenWorker (Worker ):
217
216
def __await__ (self ):
218
217
async def _ ():
219
- self .status = Status .closed
220
218
raise BrokenWorkerException ("Worker Broken" )
221
219
222
220
return _ ().__await__ ()
@@ -226,13 +224,9 @@ async def _():
226
224
workers = {"good" : {"cls" : Worker }, "bad" : {"cls" : BrokenWorker }},
227
225
scheduler = scheduler ,
228
226
)
229
- try :
230
- with pytest .raises (BrokenWorkerException , match = r"Worker Broken" ):
231
- async with cluster :
232
- pass
233
- finally :
234
- # FIXME: SpecCluster leaks if SpecCluster.__aenter__ raises
235
- await cluster .close ()
227
+ with pytest .raises (BrokenWorkerException , match = r"Worker Broken" ):
228
+ async with cluster :
229
+ pass
236
230
237
231
238
232
@pytest .mark .skipif (WINDOWS , reason = "HTTP Server doesn't close out" )
You can’t perform that action at this time.
0 commit comments