# main.py
import torch
import numpy as np
import cv2
from PIL import Image
import base64
import io
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
from transformers import pipeline
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler

# --- API Data Models ---
class StagingRequest(BaseModel):
    image_base64: str
    prompt: str
    negative_prompt: str = "blurry, low quality, unrealistic, distorted, ugly, watermark, text, messy, deformed, extra windows, extra doors"
    seed: int = 1234
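
# Example request body (illustrative values only):
#   {
#       "image_base64": "<base64-encoded bytes of an empty-room photo>",
#       "prompt": "a cozy Scandinavian living room with a sofa and plants",
#       "seed": 1234
#   }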

# --- Global State & Model Loading ---
models = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    # STARTUP: Load all models once, before the first request is served.
    print("🚀 Server starting up: Loading AI models...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    models['segmentation_pipeline'] = pipeline("image-segmentation", model="Intel/dpt-large-ade", device=device)
    models['depth_estimator'] = pipeline("depth-estimation", model="Intel/dpt-hybrid-midas", device=device)
    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth", torch_dtype=torch_dtype)
    models['inpainting_pipe'] = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        torch_dtype=torch_dtype,
        safety_checker=None,
    ).to(device)
    # Swap in the UniPC multistep scheduler, which converges in fewer steps.
    models['inpainting_pipe'].scheduler = UniPCMultistepScheduler.from_config(models['inpainting_pipe'].scheduler.config)
    print("✅ All models loaded and ready.")
    yield
    # SHUTDOWN: Release model references.
    print("⚡ Server shutting down.")
    models.clear()

app = FastAPI(lifespan=lifespan)
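
# To serve this app locally (a sketch; assumes uvicorn is installed; port 7860
# is the Hugging Face Spaces convention, but any free port works):
#   uvicorn main:app --host 0.0.0.0 --port 7860
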
# --- Helper Functions (Core Logic) ---
def create_precise_mask(image_pil: Image.Image) -> Image.Image:
    """Build an inpainting mask covering walls, floor, and ceiling while
    protecting openings and wall-mounted fixtures (white = repaint, black = keep)."""
    segments = models['segmentation_pipeline'](image_pil)
    W, H = image_pil.size
    inclusion_mask_np = np.zeros((H, W), dtype=np.uint8)
    exclusion_mask_np = np.zeros((H, W), dtype=np.uint8)
    inclusion_labels = {"wall", "floor", "ceiling"}
    base_exclusion_labels = {"door", "window", "windowpane", "window blind"}
    insert_labels = {"painting", "picture", "shelf", "showcase", "cabinet", "mirror", "television", "radiator"}
    walls, inserts = [], []
    for segment in segments:
        label, mask = segment['label'], np.array(segment['mask'])
        if label in inclusion_labels:
            inclusion_mask_np = np.maximum(inclusion_mask_np, mask)
            if label == "wall":
                walls.append(mask)
        if label in base_exclusion_labels:
            exclusion_mask_np = np.maximum(exclusion_mask_np, mask)
        if label in insert_labels:
            inserts.append(mask)
    # Exclude wall-mounted objects (paintings, mirrors, TVs, etc.) from
    # repainting only when some wall segment fully contains them.
    for insert_mask in inserts:
        for wall_mask in walls:
            if np.all((wall_mask >= insert_mask)[insert_mask > 0]):
                exclusion_mask_np = np.maximum(exclusion_mask_np, insert_mask)
                break
    raw_mask_np = np.copy(inclusion_mask_np)
    raw_mask_np[exclusion_mask_np > 0] = 0
    # Morphological closing fills small holes so the inpainting region is contiguous.
    mask_filled_np = cv2.morphologyEx(raw_mask_np, cv2.MORPH_CLOSE, np.ones((10, 10), np.uint8))
    return Image.fromarray(mask_filled_np)


def generate_depth_map(image_pil: Image.Image) -> Image.Image:
    """Estimate a depth map and return it as a 3-channel PIL image for ControlNet conditioning."""
    predicted_depth = models['depth_estimator'](image_pil)['predicted_depth']
    # The pipeline returns a (1, H, W) tensor; drop the batch dim before conversion.
    depth_map_np = predicted_depth.squeeze().cpu().numpy()
    # Normalize to 0-255 and replicate to three channels.
    depth_map_np = (depth_map_np - depth_map_np.min()) / (depth_map_np.max() - depth_map_np.min()) * 255.0
    depth_map_np = depth_map_np.astype(np.uint8)
    return Image.fromarray(np.concatenate([depth_map_np[..., None]] * 3, axis=-1))


# --- API Endpoints ---
@app.get("/")
def read_root():
    return {"status": "Virtual Staging API is running."}


@app.post("/furnish-room/")
async def furnish_room(request: StagingRequest):
    try:
        image_bytes = base64.b64decode(request.image_base64)
        init_image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB").resize((512, 512))
        mask_image_pil = create_precise_mask(init_image_pil)
        control_image_pil = generate_depth_map(init_image_pil)
        # Seed on whichever device the pipeline actually runs on, so CPU-only
        # hosts do not crash on a CUDA generator.
        gen_device = "cuda" if torch.cuda.is_available() else "cpu"
        generator = torch.Generator(device=gen_device).manual_seed(request.seed)
        final_image = models['inpainting_pipe'](
            prompt=request.prompt, negative_prompt=request.negative_prompt, image=init_image_pil,
            mask_image=mask_image_pil, control_image=control_image_pil,
            num_inference_steps=30, guidance_scale=8.0, generator=generator,
        ).images[0]
        buffered = io.BytesIO()
        final_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"result_image_base64": img_str}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
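

# Example client call (a sketch; assumes the API is reachable at
# http://localhost:7860 and that "room.jpg" is a photo of an empty room):
#
#   import base64, requests
#   with open("room.jpg", "rb") as f:
#       encoded = base64.b64encode(f.read()).decode("utf-8")
#   resp = requests.post(
#       "http://localhost:7860/furnish-room/",
#       json={"image_base64": encoded, "prompt": "a cozy Scandinavian living room"},
#   )
#   resp.raise_for_status()
#   with open("staged.png", "wb") as out:
#       out.write(base64.b64decode(resp.json()["result_image_base64"]))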