| | import json |
| | from pathlib import Path |
| | from typing import Dict, Optional |
| |
|
| | import cv2 |
| | import psutil |
| | from PIL import Image |
| | from loguru import logger |
| | from rich.console import Console |
| | from rich.progress import ( |
| | Progress, |
| | SpinnerColumn, |
| | TimeElapsedColumn, |
| | MofNCompleteColumn, |
| | TextColumn, |
| | BarColumn, |
| | TaskProgressColumn, |
| | ) |
| |
|
| | from iopaint.helper import pil_to_bytes |
| | from iopaint.model.utils import torch_gc |
| | from iopaint.model_manager import ModelManager |
| | from iopaint.schema import InpaintRequest |
| |
|
| |
|
def glob_images(path: Path) -> Dict[str, Path]:
    """Collect image files at *path*, keyed by file stem.

    Args:
        path: A single file, or a directory that is scanned non-recursively.

    Returns:
        A mapping of file stem to path. A file path is returned as-is, with
        no extension check. For a directory, only entries whose suffix is
        ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) are kept; when two
        entries share a stem, the one that sorts last wins. A path that is
        neither a file nor a directory yields an empty mapping.
    """
    if path.is_file():
        return {path.stem: path}
    if path.is_dir():
        # Sort so that processing order, and the winner of a stem collision
        # such as ``a.jpg`` vs ``a.png``, does not depend on the order in
        # which the filesystem happens to enumerate entries.
        return {
            it.stem: it
            for it in sorted(path.glob("*.*"))
            if it.suffix.lower() in (".png", ".jpg", ".jpeg")
        }
    # Missing or special path: honour the declared return type rather than
    # implicitly returning None.
    return {}
| |
|
| |
|
def batch_inpaint(
    model: str,
    device,
    image: Path,
    mask: Path,
    output: Path,
    config: Optional[Path] = None,
    concat: bool = False,
):
    """Inpaint one image or a folder of images and save the results as PNG.

    Args:
        model: Model name passed to ``ModelManager``.
        device: Device passed to ``ModelManager``.
        image: An image file, or a directory of images.
        mask: A mask file (then reused for every image), or a directory of
            masks matched to images by file stem. With a mask directory,
            images that have no matching mask are logged and skipped.
        output: Directory that receives ``<stem>.png`` files; created when
            missing.
        config: Optional JSON file whose top-level keys are passed as keyword
            arguments to ``InpaintRequest``. A default request is used when
            omitted.
        concat: When true, the saved picture is the input image, the mask and
            the result concatenated horizontally.

    Raises:
        SystemExit: With code ``-1`` when ``image`` is a directory while
            ``output`` is an existing file, or when no images or no masks
            are found.
    """
    if image.is_dir() and output.is_file():
        logger.error(
            "invalid --output: when image is a directory, output should be a directory"
        )
        # ``raise SystemExit`` is what the ``exit`` builtin does, but it does
        # not depend on the ``site`` module having injected that builtin.
        raise SystemExit(-1)
    output.mkdir(parents=True, exist_ok=True)

    image_paths = glob_images(image)
    mask_paths = glob_images(mask)
    # Truthiness covers both an empty mapping and a missing result.
    if not image_paths:
        logger.error("invalid --image: empty image folder")
        raise SystemExit(-1)
    if not mask_paths:
        logger.error("invalid --mask: empty mask folder")
        raise SystemExit(-1)

    if config is None:
        inpaint_request = InpaintRequest()
        logger.info(f"Using default config: {inpaint_request}")
    else:
        with open(config, "r", encoding="utf-8") as f:
            inpaint_request = InpaintRequest(**json.load(f))

    model_manager = ModelManager(name=model, device=device)
    # Fallback used when ``mask`` is a single file shared by all images.
    first_mask = next(iter(mask_paths.values()))

    console = Console()

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        console=console,
        transient=False,
    ) as progress:
        task = progress.add_task("Batch processing...", total=len(image_paths))
        for stem, image_p in image_paths.items():
            if stem not in mask_paths and mask.is_dir():
                progress.log(f"mask for {image_p} not found")
                progress.update(task, advance=1)
                continue
            mask_p = mask_paths.get(stem, first_mask)

            # Only the metadata dict is needed from Pillow; close the file
            # right away instead of leaking one handle per image.
            with Image.open(image_p) as pil_img:
                infos = pil_img.info

            img = cv2.imread(str(image_p))
            mask_img = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
            # cv2.imread signals failure by returning None, not by raising.
            # Skip the pair so one bad file does not abort the whole batch.
            if img is None or mask_img is None:
                unreadable = image_p if img is None else mask_p
                progress.log(f"failed to read {unreadable}, skipping")
                progress.update(task, advance=1)
                continue

            # Default imread yields 3-channel BGR, so BGR2RGB is the
            # conversion that is actually performed here.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if mask_img.shape[:2] != img.shape[:2]:
                progress.log(
                    f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}"
                )
                mask_img = cv2.resize(
                    mask_img,
                    (img.shape[1], img.shape[0]),
                    interpolation=cv2.INTER_NEAREST,
                )
            # Binarize the mask: values >= 127 become 255, the rest 0.
            mask_img[mask_img >= 127] = 255
            mask_img[mask_img < 127] = 0

            inpaint_result = model_manager(img, mask_img, inpaint_request)
            # NOTE(review): the channel swap implies the model returns BGR
            # — confirm against ModelManager.
            inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB)
            if concat:
                mask_img = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB)
                inpaint_result = cv2.hconcat([img, mask_img, inpaint_result])

            img_bytes = pil_to_bytes(Image.fromarray(inpaint_result), "png", 100, infos)
            save_p = output / f"{stem}.png"
            with open(save_p, "wb") as fw:
                fw.write(img_bytes)

            progress.update(task, advance=1)
            # Presumably releases cached torch memory between images
            # — confirm against iopaint.model.utils.torch_gc.
            torch_gc()
| | |
| | |
| | |
| | |
| |
|