Skip to content

bpo-46752: Slight improvements to TaskGroup API #31398

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 4 additions & 18 deletions Lib/asyncio/taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@

__all__ = ["TaskGroup"]

import itertools
import textwrap
import traceback
import types
import weakref

from . import events
Expand All @@ -15,12 +11,7 @@

class TaskGroup:

def __init__(self, *, name=None):
if name is None:
self._name = f'tg-{_name_counter()}'
else:
self._name = str(name)

def __init__(self):
self._entered = False
self._exiting = False
self._aborting = False
Expand All @@ -33,11 +24,8 @@ def __init__(self, *, name=None):
self._base_error = None
self._on_completed_fut = None

def get_name(self):
return self._name

def __repr__(self):
msg = f'<TaskGroup {self._name!r}'
msg = f'<TaskGroup'
if self._tasks:
msg += f' tasks:{len(self._tasks)}'
if self._unfinished_tasks:
Expand Down Expand Up @@ -152,12 +140,13 @@ async def __aexit__(self, et, exc, tb):
me = BaseExceptionGroup('unhandled errors in a TaskGroup', errors)
raise me from None

def create_task(self, coro):
def create_task(self, coro, *, name=None):
if not self._entered:
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
if self._exiting and self._unfinished_tasks == 0:
raise RuntimeError(f"TaskGroup {self!r} is finished")
task = self._loop.create_task(coro)
tasks._set_task_name(task, name)
task.add_done_callback(self._on_task_done)
self._unfinished_tasks += 1
self._tasks.add(task)
Expand Down Expand Up @@ -230,6 +219,3 @@ def _on_task_done(self, task):
# # after TaskGroup is finished.
self._parent_cancel_requested = True
self._parent_task.cancel()


_name_counter = itertools.count(1).__next__
19 changes: 13 additions & 6 deletions Lib/test/test_asyncio/test_taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,10 +368,10 @@ async def crash_after(t):
raise ValueError(t)

async def runner():
async with taskgroups.TaskGroup(name='g1') as g1:
async with taskgroups.TaskGroup() as g1:
g1.create_task(crash_after(0.1))

async with taskgroups.TaskGroup(name='g2') as g2:
async with taskgroups.TaskGroup() as g2:
g2.create_task(crash_after(0.2))

r = asyncio.create_task(runner())
Expand All @@ -387,10 +387,10 @@ async def crash_after(t):
raise ValueError(t)

async def runner():
async with taskgroups.TaskGroup(name='g1') as g1:
async with taskgroups.TaskGroup() as g1:
g1.create_task(crash_after(10))

async with taskgroups.TaskGroup(name='g2') as g2:
async with taskgroups.TaskGroup() as g2:
g2.create_task(crash_after(0.1))

r = asyncio.create_task(runner())
Expand All @@ -407,7 +407,7 @@ async def crash_soon():
1 / 0

async def runner():
async with taskgroups.TaskGroup(name='g1') as g1:
async with taskgroups.TaskGroup() as g1:
g1.create_task(crash_soon())
try:
await asyncio.sleep(10)
Expand All @@ -430,7 +430,7 @@ async def crash_soon():
1 / 0

async def nested_runner():
async with taskgroups.TaskGroup(name='g1') as g1:
async with taskgroups.TaskGroup() as g1:
g1.create_task(crash_soon())
try:
await asyncio.sleep(10)
Expand Down Expand Up @@ -692,3 +692,10 @@ async def runner():

self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
self.assertGreaterEqual(nhydras, 10)

async def test_taskgroup_task_name(self):
async def coro():
await asyncio.sleep(0)
async with taskgroups.TaskGroup() as g:
t = g.create_task(coro(), name="yolo")
self.assertEqual(t.get_name(), "yolo")