Spaces: Build error
```python
# main.py
import torch
import numpy as np
import cv2
from PIL import Image
import base64
import io
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
from transformers import pipeline
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler


# --- API Data Models ---
class StagingRequest(BaseModel):
    image_base64: str
    prompt: str
    negative_prompt: str = "blurry, low quality, unrealistic, distorted, ugly, watermark, text, messy, deformed, extra windows, extra doors"
    seed: int = 1234


# --- Global State & Model Loading ---
models = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    # STARTUP: Load all models
    print("Server starting up: loading AI models...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    models['segmentation_pipeline'] = pipeline("image-segmentation", model="Intel/dpt-large-ade", device=device)
    models['depth_estimator'] = pipeline("depth-estimation", model="Intel/dpt-hybrid-midas", device=device)
    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth", torch_dtype=torch_dtype)
    models['inpainting_pipe'] = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        torch_dtype=torch_dtype,
        safety_checker=None,
    ).to(device)
    models['inpainting_pipe'].scheduler = UniPCMultistepScheduler.from_config(models['inpainting_pipe'].scheduler.config)
    print("All models loaded and ready.")
    yield
    # SHUTDOWN: Clean up
    print("Server shutting down.")
    models.clear()


app = FastAPI(lifespan=lifespan)


# --- Helper Functions (Core Logic) ---
def create_precise_mask(image_pil: Image.Image) -> Image.Image:
    segments = models['segmentation_pipeline'](image_pil)
    W, H = image_pil.size
    inclusion_mask_np = np.zeros((H, W), dtype=np.uint8)
    exclusion_mask_np = np.zeros((H, W), dtype=np.uint8)
    inclusion_labels = {"wall", "floor", "ceiling"}
    base_exclusion_labels = {"door", "window", "windowpane", "window blind"}
    insert_labels = {"painting", "picture", "shelf", "showcase", "cabinet", "mirror", "television", "radiator"}
    walls, inserts = [], []
    for segment in segments:
        label, mask = segment['label'], np.array(segment['mask'])
        if label in inclusion_labels:
            inclusion_mask_np = np.maximum(inclusion_mask_np, mask)
            if label == "wall":
                walls.append(mask)
        if label in base_exclusion_labels:
            exclusion_mask_np = np.maximum(exclusion_mask_np, mask)
        if label in insert_labels:
            inserts.append(mask)
    # Also exclude wall-mounted objects whose mask is fully covered by a wall mask.
    for insert_mask in inserts:
        for wall_mask in walls:
            if np.all((wall_mask >= insert_mask)[insert_mask > 0]):
                exclusion_mask_np = np.maximum(exclusion_mask_np, insert_mask)
                break
    raw_mask_np = np.copy(inclusion_mask_np)
    raw_mask_np[exclusion_mask_np > 0] = 0
    # Morphological closing fills small holes so the inpainting mask is contiguous.
    mask_filled_np = cv2.morphologyEx(raw_mask_np, cv2.MORPH_CLOSE, np.ones((10, 10), np.uint8))
    return Image.fromarray(mask_filled_np)


def generate_depth_map(image_pil: Image.Image) -> Image.Image:
    predicted_depth = models['depth_estimator'](image_pil)['predicted_depth']
    # Drop the batch dimension before converting; Image.fromarray rejects 4-D arrays.
    depth_map_np = predicted_depth.squeeze().cpu().numpy()
    # Normalize to 0-255 and replicate to three channels for the ControlNet input.
    depth_map_np = (depth_map_np - depth_map_np.min()) / (depth_map_np.max() - depth_map_np.min()) * 255.0
    depth_map_np = depth_map_np.astype(np.uint8)
    depth_rgb = np.concatenate([depth_map_np[..., None]] * 3, axis=-1)
    # The estimator's output resolution differs from the input, so resize to match.
    return Image.fromarray(depth_rgb).resize(image_pil.size)


# --- API Endpoints ---
@app.get("/")
def read_root():
    return {"status": "Virtual Staging API is running."}


@app.post("/furnish")  # route path is an assumption; the original decorator was missing
async def furnish_room(request: StagingRequest):
    try:
        image_bytes = base64.b64decode(request.image_base64)
        init_image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB").resize((512, 512))
        mask_image_pil = create_precise_mask(init_image_pil)
        control_image_pil = generate_depth_map(init_image_pil)
        # Seed on whichever device is available; hard-coding "cuda" crashes on CPU-only hosts.
        generator_device = "cuda" if torch.cuda.is_available() else "cpu"
        generator = torch.Generator(device=generator_device).manual_seed(request.seed)
        final_image = models['inpainting_pipe'](
            prompt=request.prompt, negative_prompt=request.negative_prompt, image=init_image_pil,
            mask_image=mask_image_pil, control_image=control_image_pil,
            num_inference_steps=30, guidance_scale=8.0, generator=generator,
        ).images[0]
        buffered = io.BytesIO()
        final_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"result_image_base64": img_str}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
```
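
For testing, a minimal client sketch follows. It assumes the hypothetical `/furnish` route path used above, a server started with `uvicorn main:app --host 0.0.0.0 --port 7860` (7860 is the usual default for Docker-based Spaces), and local filenames `room.jpg` / `staged.png`; adjust all of these to your setup.

```python
# client.py -- minimal sketch; the /furnish path, port, and filenames are assumptions
import base64
import requests

with open("room.jpg", "rb") as f:
    payload = {
        "image_base64": base64.b64encode(f.read()).decode("utf-8"),
        "prompt": "a cozy Scandinavian living room, sofa and coffee table, soft natural light",
    }

# Diffusion can take minutes on CPU, so allow a generous timeout.
resp = requests.post("http://localhost:7860/furnish", json=payload, timeout=600)
resp.raise_for_status()

with open("staged.png", "wb") as f:
    f.write(base64.b64decode(resp.json()["result_image_base64"]))
print("Saved staged.png")
```

Note that the endpoint both accepts and returns base64-encoded images, so no multipart upload handling is needed on either side; the trade-off is roughly 33% payload overhead versus raw bytes.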