# main.py
import base64
import io
from contextlib import asynccontextmanager

import cv2
import numpy as np
import torch
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetInpaintPipeline,
    UniPCMultistepScheduler,
)
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import pipeline


# --- API Data Models ---
class StagingRequest(BaseModel):
    image_base64: str
    prompt: str
    negative_prompt: str = (
        "blurry, low quality, unrealistic, distorted, ugly, watermark, text, "
        "messy, deformed, extra windows, extra doors"
    )
    seed: int = 1234


# --- Global State & Model Loading ---
models = {}
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@asynccontextmanager
async def lifespan(app: FastAPI):
    # STARTUP: Load all models
    print("🚀 Server starting up: Loading AI models...")
    torch_dtype = torch.float16 if DEVICE == "cuda" else torch.float32
    models['segmentation_pipeline'] = pipeline(
        "image-segmentation", model="Intel/dpt-large-ade", device=DEVICE
    )
    models['depth_estimator'] = pipeline(
        "depth-estimation", model="Intel/dpt-hybrid-midas", device=DEVICE
    )
    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/sd-controlnet-depth", torch_dtype=torch_dtype
    )
    models['inpainting_pipe'] = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        torch_dtype=torch_dtype,
        safety_checker=None,
    ).to(DEVICE)
    models['inpainting_pipe'].scheduler = UniPCMultistepScheduler.from_config(
        models['inpainting_pipe'].scheduler.config
    )
    print("✅ All models loaded and ready.")
    yield
    # SHUTDOWN: Clean up
    print("⚡ Server shutting down.")
    models.clear()


app = FastAPI(lifespan=lifespan)


# --- Helper Functions (Core Logic) ---
def create_precise_mask(image_pil: Image.Image) -> Image.Image:
    """Build an inpainting mask that covers walls, floor, and ceiling while
    protecting doors, windows, and wall-mounted fixtures."""
    segments = models['segmentation_pipeline'](image_pil)
    W, H = image_pil.size
    inclusion_mask_np = np.zeros((H, W), dtype=np.uint8)
    exclusion_mask_np = np.zeros((H, W), dtype=np.uint8)

    inclusion_labels = {"wall", "floor", "ceiling"}
    base_exclusion_labels = {"door", "window", "windowpane", "window blind"}
    insert_labels = {
        "painting", "picture", "shelf", "showcase",
        "cabinet", "mirror", "television", "radiator",
    }

    walls, inserts = [], []
    for segment in segments:
        label, mask = segment['label'], np.array(segment['mask'])
        if label in inclusion_labels:
            inclusion_mask_np = np.maximum(inclusion_mask_np, mask)
            if label == "wall":
                walls.append(mask)
        if label in base_exclusion_labels:
            exclusion_mask_np = np.maximum(exclusion_mask_np, mask)
        if label in insert_labels:
            inserts.append(mask)

    # Exclude an "insert" (e.g. a painting) only when it lies entirely on a
    # wall, i.e. every pixel of the insert is also a wall pixel.
    for insert_mask in inserts:
        for wall_mask in walls:
            if np.all((wall_mask >= insert_mask)[insert_mask > 0]):
                exclusion_mask_np = np.maximum(exclusion_mask_np, insert_mask)
                break

    raw_mask_np = np.copy(inclusion_mask_np)
    raw_mask_np[exclusion_mask_np > 0] = 0
    # Close small holes so the mask forms contiguous paintable regions.
    mask_filled_np = cv2.morphologyEx(
        raw_mask_np, cv2.MORPH_CLOSE, np.ones((10, 10), np.uint8)
    )
    return Image.fromarray(mask_filled_np)


def generate_depth_map(image_pil: Image.Image) -> Image.Image:
    """Estimate a depth map and return it as a 3-channel PIL image for ControlNet."""
    predicted_depth = models['depth_estimator'](image_pil)['predicted_depth']
    # predicted_depth has shape (1, H', W'); drop the batch dimension before
    # converting, then normalize to the 0-255 range.
    depth_map_np = predicted_depth.squeeze().cpu().numpy()
    depth_map_np = (depth_map_np - depth_map_np.min()) / (
        depth_map_np.max() - depth_map_np.min()
    ) * 255.0
    depth_map_np = depth_map_np.astype(np.uint8)
    depth_rgb = np.concatenate([depth_map_np[..., None]] * 3, axis=-1)
    # Resize back to the input resolution so the control image matches the
    # init image.
    return Image.fromarray(depth_rgb).resize(image_pil.size)
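
# --- Optional local smoke test for the helpers (illustrative sketch) ---
# Not part of the API. Assumes a local image "test_room.jpg" and that the
# models dict has already been populated (e.g. by running the startup code
# from a notebook or script), since the helpers read from it directly.
#
#   img = Image.open("test_room.jpg").convert("RGB").resize((512, 512))
#   create_precise_mask(img).save("debug_mask.png")
#   generate_depth_map(img).save("debug_depth.png")
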
# --- API Endpoints ---
@app.get("/")
def read_root():
    return {"status": "Virtual Staging API is running."}


@app.post("/furnish-room/")
async def furnish_room(request: StagingRequest):
    try:
        image_bytes = base64.b64decode(request.image_base64)
        init_image_pil = (
            Image.open(io.BytesIO(image_bytes)).convert("RGB").resize((512, 512))
        )
        mask_image_pil = create_precise_mask(init_image_pil)
        control_image_pil = generate_depth_map(init_image_pil)
        # Seed on the active device so CPU-only deployments don't crash.
        generator = torch.Generator(device=DEVICE).manual_seed(request.seed)
        final_image = models['inpainting_pipe'](
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            image=init_image_pil,
            mask_image=mask_image_pil,
            control_image=control_image_pil,
            num_inference_steps=30,
            guidance_scale=8.0,
            generator=generator,
        ).images[0]
        buffered = io.BytesIO()
        final_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"result_image_base64": img_str}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
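
# --- Example client call (illustrative sketch, not part of the server) ---
# Assumes the server is running locally (e.g. `uvicorn main:app`) and that a
# file "room.jpg" exists; the URL, filenames, and prompt are placeholders.
#
#   import base64, requests
#
#   with open("room.jpg", "rb") as f:
#       payload = {
#           "image_base64": base64.b64encode(f.read()).decode("utf-8"),
#           "prompt": "a cozy Scandinavian living room with a sofa and coffee table",
#       }
#   resp = requests.post("http://127.0.0.1:8000/furnish-room/", json=payload)
#   resp.raise_for_status()
#   with open("staged_room.png", "wb") as f:
#       f.write(base64.b64decode(resp.json()["result_image_base64"]))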