Spaces:
Running
Running
| import asyncio | |
| from typing import Awaitable, Callable, List, TypeVar | |
| from tqdm.asyncio import tqdm as tqdm_async | |
| from graphgen.utils.log import logger | |
| from .loop import create_event_loop | |
T = TypeVar("T")
R = TypeVar("R")


def run_concurrent(
    coro_fn: Callable[[T], Awaitable[R]],
    items: List[T],
    *,
    desc: str = "processing",
    unit: str = "item",
) -> List[R]:
    """Apply ``coro_fn`` to every item concurrently; return results in input order.

    All coroutines are scheduled at once on the loop returned by
    ``create_event_loop()``, and a progress bar advances as each one finishes.
    That loop is closed before this function returns.

    Args:
        coro_fn: Async callable applied to each element of ``items``.
        items: Inputs to process; the returned list is aligned with it.
        desc: Label shown on the progress bar.
        unit: Unit name shown on the progress bar.

    Returns:
        A list with one entry per input, where entry ``i`` is the value
        produced by ``coro_fn(items[i])``. If that call raised an
        ``Exception``, the error is logged together with its traceback and
        entry ``i`` is ``None``. Failures neither abort the other tasks nor
        propagate to the caller.
    """
    if not items:
        # Nothing to schedule: avoid creating (and then closing) an event loop.
        return []

    async def _run_all() -> List[R]:
        async def _worker(index: int, item: T):
            # Return the input index alongside the outcome so results can be
            # slotted back into input order regardless of completion order.
            try:
                res = await coro_fn(item)
                return index, res, None
            except Exception as e:  # deliberate: one failure must not sink the batch
                return index, None, e

        tasks = [
            asyncio.create_task(_worker(i, item)) for i, item in enumerate(items)
        ]
        # Slots for failed tasks stay None.
        results: List[R | None] = [None] * len(items)

        # The context manager closes the bar even if iteration is interrupted.
        with tqdm_async(total=len(items), desc=desc, unit=unit) as pbar:
            for future in asyncio.as_completed(tasks):
                idx, result, error = await future
                if error is not None:
                    # We are outside any ``except`` block here, so
                    # ``logger.exception`` would have no active exception to
                    # report; hand the captured one over explicitly instead.
                    logger.error(
                        "Task failed at index %d: %s", idx, error, exc_info=error
                    )
                else:
                    results[idx] = result
                pbar.update(1)
        return results

    loop = create_event_loop()
    try:
        return loop.run_until_complete(_run_all())
    finally:
        loop.close()