Spaces:
Running
Running
| import asyncio | |
| from typing import Awaitable, Callable, List, Optional, TypeVar | |
| import gradio as gr | |
| from tqdm.asyncio import tqdm as tqdm_async | |
| from graphgen.utils.log import logger | |
# Generic type variables shared by the helpers in this module:
# ``T`` stands for an input item type and ``R`` for a result type.
T, R = TypeVar("T"), TypeVar("R")
| # async def run_concurrent( | |
| # coro_fn: Callable[[T], Awaitable[R]], | |
| # items: List[T], | |
| # *, | |
| # desc: str = "processing", | |
| # unit: str = "item", | |
| # progress_bar: Optional[gr.Progress] = None, | |
| # ) -> List[R]: | |
| # tasks = [asyncio.create_task(coro_fn(it)) for it in items] | |
| # | |
| # results = [] | |
| # async for future in tqdm_async( | |
| # tasks, desc=desc, unit=unit | |
| # ): | |
| # try: | |
| # result = await future | |
| # results.append(result) | |
| # except Exception as e: # pylint: disable=broad-except | |
| # logger.exception("Task failed: %s", e) | |
| # | |
| # if progress_bar is not None: | |
| # progress_bar((len(results)) / len(items), desc=desc) | |
| # | |
| # if progress_bar is not None: | |
| # progress_bar(1.0, desc=desc) | |
| # return results | |
| # results = await tqdm_async.gather(*tasks, desc=desc, unit=unit) | |
| # | |
| # ok_results = [] | |
| # for idx, res in enumerate(results): | |
| # if isinstance(res, Exception): | |
| # logger.exception("Task failed: %s", res) | |
| # if progress_bar: | |
| # progress_bar((idx + 1) / len(items), desc=desc) | |
| # continue | |
| # ok_results.append(res) | |
| # if progress_bar: | |
| # progress_bar((idx + 1) / len(items), desc=desc) | |
| # | |
| # if progress_bar: | |
| # progress_bar(1.0, desc=desc) | |
| # return ok_results | |
| # async def run_concurrent( | |
| # coro_fn: Callable[[T], Awaitable[R]], | |
| # items: List[T], | |
| # *, | |
| # desc: str = "processing", | |
| # unit: str = "item", | |
| # progress_bar: Optional[gr.Progress] = None, | |
| # ) -> List[R]: | |
| # tasks = [asyncio.create_task(coro_fn(it)) for it in items] | |
| # | |
| # results = [] | |
| # # 使用同步方式更新进度条,避免异步冲突 | |
| # for i, task in enumerate(asyncio.as_completed(tasks)): | |
| # try: | |
| # result = await task | |
| # results.append(result) | |
| # # 同步更新进度条 | |
| # if progress_bar is not None: | |
| # # 在同步上下文中更新进度 | |
| # progress_bar((i + 1) / len(items), desc=desc) | |
| # except Exception as e: | |
| # logger.exception("Task failed: %s", e) | |
| # results.append(e) | |
| # | |
| # return results | |
async def run_concurrent(
    coro_fn: Callable[[T], Awaitable[R]],
    items: List[T],
    *,
    desc: str = "processing",
    unit: str = "item",
    progress_bar: Optional[gr.Progress] = None,
) -> List[R]:
    """Apply ``coro_fn`` to every item concurrently and collect the successes.

    One task is scheduled per item up front, so all calls run concurrently.
    After each task finishes (successfully or not), progress is reported to a
    tqdm bar and, when given, to ``progress_bar``.

    Args:
        coro_fn: Called once per item; must return an awaitable.
        items: Inputs to process.
        desc: Label shown on both progress displays.
        unit: Unit name shown on the tqdm bar.
        progress_bar: Optional Gradio progress tracker, called as
            ``progress_bar(fraction, desc=...)``.

    Returns:
        The results of the calls that succeeded, in the same order as
        ``items``. Calls that raised an ``Exception`` are logged via
        ``logger.exception`` and omitted, so the list may be shorter than
        ``items``.

    Raises:
        Exceptions that are not ``Exception`` subclasses (for example
        ``asyncio.CancelledError``) propagate; before they do, the tqdm bar
        is closed and any still-pending tasks are cancelled.
    """
    total = len(items)
    # ensure_future accepts any awaitable (not only coroutines), matching the
    # ``Awaitable[R]`` annotation of ``coro_fn``.
    tasks = [asyncio.ensure_future(coro_fn(it)) for it in items]

    pbar = tqdm_async(total=total, desc=desc, unit=unit)
    if progress_bar is not None:
        progress_bar(0.0, desc=f"{desc} (0/{total})")

    try:
        # as_completed yields in *completion* order; it only drives progress
        # reporting and failure logging, not the order of the returned list.
        completions = asyncio.as_completed(tasks)
        for completed_count, future in enumerate(completions, start=1):
            try:
                await future
            except Exception as e:  # pylint: disable=broad-except
                logger.exception("Task failed: %s", e)
            pbar.update(1)
            if progress_bar is not None:
                progress_bar(
                    completed_count / total,
                    desc=f"{desc} ({completed_count}/{total})",
                )
    finally:
        pbar.close()
        # On early exit (e.g. cancellation) don't leave orphaned tasks running.
        for task in tasks:
            if not task.done():
                task.cancel()

    if progress_bar is not None:
        progress_bar(1.0, desc=f"{desc} (completed)")

    # Every task is finished here; keep successful results in input order.
    # Checking task state (rather than isinstance(..., Exception)) also keeps
    # results that legitimately happen to be exception objects.
    return [
        task.result()
        for task in tasks
        if not task.cancelled() and task.exception() is None
    ]