diff --git a/tests/test_driver.py b/tests/test_driver.py index dd33c08e6..4688bd231 100755 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import argparse import asyncio +import contextlib from dataclasses import dataclass from datetime import datetime import os @@ -115,10 +116,20 @@ async def main(): argparser = argparse.ArgumentParser( description="test_driver: Run fish's test suite" ) + argparser.add_argument( + "--max-concurrency", + type=int, + help="Maximum number of tests to run concurrently. The default is to run all tests concurrently.", + ) argparser.add_argument("fish", nargs=1, help="Fish to test") argparser.add_argument("file", nargs="*", help="Tests to run") args = argparser.parse_args() + max_concurrency = args.max_concurrency + if max_concurrency is not None and max_concurrency < 1: + print("--max-concurrency must be at least 1") + sys.exit(1) + fishdir = Path(args.fish[0]).absolute() if not fishdir.is_dir(): fishdir = fishdir.parent @@ -176,10 +187,19 @@ async def main(): tmp_root / "fish_test_helper", ) - tasks = [ - run_test(tmp_root, f, arg, script_path, def_subs, lconfig, fishdir) - for f, arg in files - ] + semaphore = ( + asyncio.Semaphore(max_concurrency) + if max_concurrency + else contextlib.nullcontext() + ) + + async def run(f, arg): + async with semaphore: + return await run_test( + tmp_root, f, arg, script_path, def_subs, lconfig, fishdir + ) + + tasks = [run(f, arg) for f, arg in files] for task in asyncio.as_completed(tasks): result = await task # TODO(python>3.8): use match statement