diff --git a/tests/test_driver.py b/tests/test_driver.py index 20e50ea38..5c7cf15dc 100755 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -193,7 +193,7 @@ async def main(): semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None - async def run(f, arg): + async def run(f, arg) -> TestResult: # TODO(python>3.8): use "async with" if semaphore is not None: await semaphore.acquire() @@ -205,7 +205,7 @@ async def main(): if semaphore is not None: semaphore.release() - tasks = [asyncio.create_task(run(f, arg), name=arg) for f, arg in files] + tasks = [create_task(run(f, arg), name=arg) for f, arg in files] while tasks: done, tasks = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) for task in done: @@ -276,6 +276,14 @@ class TestPass: TestResult = Union[TestSkip, TestFail, TestPass] +# TODO(python>3.8): use asyncio.create_task +def create_task(coro, name: str) -> asyncio.Task: + task = asyncio.Task(coro) + if sys.version_info >= (3, 8): + task.set_name(name) + return task + + async def run_test( tmp_root: Path, test_file_path: str,