File size: 4,920 Bytes
8f78b88
 
 
 
 
 
 
 
 
 
 
9fe5653
8f78b88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# main.py
import torch
import numpy as np
import cv2
from PIL import Image
import base64
import io
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
from transformers import pipeline
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler

# --- API Data Models ---
class StagingRequest(BaseModel):
    # Request body accepted by the POST /furnish-room/ endpoint.

    # Base64-encoded bytes of the input image; the handler decodes it with
    # base64.b64decode and opens the result with PIL.
    image_base64: str
    # Text prompt forwarded unchanged to the inpainting pipeline.
    prompt: str
    # Negative prompt forwarded to the inpainting pipeline; the default lists
    # artefacts to steer away from.
    negative_prompt: str = "blurry, low quality, unrealistic, distorted, ugly, watermark, text, messy, deformed, extra windows, extra doors"
    # Seed for the torch random generator used during generation.
    seed: int = 1234

# --- Global State & Model Loading ---
# Registry of loaded model pipelines, keyed by name. It is filled by the
# lifespan handler at startup and emptied again at shutdown.
models = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every model at startup and release them at shutdown.

    Populates the module-level ``models`` dict with three entries:
    ``segmentation_pipeline``, ``depth_estimator`` and ``inpainting_pipe``.
    CUDA with float16 is used when available, otherwise CPU with float32.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None, once all models are loaded; the server handles requests while
        this generator is suspended.
    """
    # STARTUP: Load all models
    print("🚀 Server starting up: Loading AI models...")
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    torch_dtype = torch.float16 if use_cuda else torch.float32

    try:
        models['segmentation_pipeline'] = pipeline("image-segmentation", model="Intel/dpt-large-ade", device=device)
        models['depth_estimator'] = pipeline("depth-estimation", model="Intel/dpt-hybrid-midas", device=device)

        controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth", torch_dtype=torch_dtype)

        inpainting_pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=controlnet,
            torch_dtype=torch_dtype,
            safety_checker=None
        ).to(device)
        # Swap the default scheduler for UniPC, reusing the existing scheduler config.
        inpainting_pipe.scheduler = UniPCMultistepScheduler.from_config(inpainting_pipe.scheduler.config)
        models['inpainting_pipe'] = inpainting_pipe

        print("✅ All models loaded and ready.")
        yield
    finally:
        # SHUTDOWN: Clean up. Runs in `finally` so the model references are
        # dropped even when startup fails half-way or an exception or
        # cancellation is thrown into the generator at the `yield`.
        print("⚡ Server shutting down.")
        models.clear()

# Application instance; model loading and cleanup are tied to its lifespan handler.
app = FastAPI(lifespan=lifespan)

# --- Helper Functions (Core Logic) ---
def create_precise_mask(image_pil: Image.Image) -> Image.Image:
    """Build the inpainting mask for a room photo from its semantic segments.

    The mask is the union of the "wall", "floor" and "ceiling" segments, with
    these regions zeroed out:

    * every door / window-type segment, and
    * every wall-mounted item segment (painting, shelf, mirror, ...) whose
      non-zero pixels are all covered by a single wall segment mask.

    A morphological close with a 10x10 kernel is applied to the result.

    Args:
        image_pil: Image passed to the segmentation pipeline.

    Returns:
        A single-channel mask image with the same width and height as
        ``image_pil``.
    """
    segmentation = models['segmentation_pipeline'](image_pil)

    paintable_labels = {"wall", "floor", "ceiling"}
    protected_labels = {"door", "window", "windowpane", "window blind"}
    mounted_labels = {"painting", "picture", "shelf", "showcase", "cabinet", "mirror", "television", "radiator"}

    width, height = image_pil.size
    keep_np = np.zeros((height, width), dtype=np.uint8)
    drop_np = np.zeros((height, width), dtype=np.uint8)
    wall_masks = []
    mounted_masks = []

    # First pass: accumulate the include/exclude unions and collect candidates.
    for entry in segmentation:
        name = entry['label']
        region = np.array(entry['mask'])
        if name in paintable_labels:
            keep_np = np.maximum(keep_np, region)
            if name == "wall":
                wall_masks.append(region)
        if name in protected_labels:
            drop_np = np.maximum(drop_np, region)
        if name in mounted_labels:
            mounted_masks.append(region)

    # Second pass: an item is excluded once any single wall mask is >= the
    # item mask at every pixel the item occupies.
    for item in mounted_masks:
        occupied = item > 0
        if any(np.all((wall >= item)[occupied]) for wall in wall_masks):
            drop_np = np.maximum(drop_np, item)

    combined_np = np.copy(keep_np)
    combined_np[drop_np > 0] = 0

    # Morphological closing with a 10x10 all-ones kernel.
    kernel = np.ones((10, 10), np.uint8)
    closed_np = cv2.morphologyEx(combined_np, cv2.MORPH_CLOSE, kernel)
    return Image.fromarray(closed_np)

def generate_depth_map(image_pil: Image.Image) -> Image.Image:
    """Estimate depth for an image and return it as a 3-channel control image.

    The depth estimator's ``predicted_depth`` output is min-max normalised to
    0..255, converted to uint8 and replicated across three channels.

    Args:
        image_pil: Image passed to the depth-estimation pipeline.

    Returns:
        An 8-bit, 3-channel image with the same size as ``image_pil``. A
        constant depth prediction yields an all-black image.
    """
    predicted_depth = models['depth_estimator'](image_pil)['predicted_depth']
    # NOTE(review): depending on the transformers version the tensor may carry
    # a leading batch axis of size 1 — squeeze defensively so fromarray gets a
    # 2-D plane. Confirm against the pinned version.
    depth_map_np = np.squeeze(predicted_depth.cpu().numpy()).astype(np.float32)

    depth_min = depth_map_np.min()
    depth_range = depth_map_np.max() - depth_min
    if depth_range > 0:
        depth_map_np = (depth_map_np - depth_min) / depth_range * 255.0
    else:
        # Constant depth: avoid dividing by zero (NaNs cast to uint8 are garbage).
        depth_map_np = np.zeros_like(depth_map_np)
    depth_u8 = depth_map_np.astype(np.uint8)

    depth_rgb = Image.fromarray(np.stack([depth_u8] * 3, axis=-1))
    # The raw prediction may be at the model's working resolution rather than
    # the input's; keep the control image aligned with the input image.
    if depth_rgb.size != image_pil.size:
        depth_rgb = depth_rgb.resize(image_pil.size)
    return depth_rgb

# --- API Endpoints ---
@app.get("/")
def read_root():
    # Root endpoint: returns a fixed status payload.
    status_payload = {"status": "Virtual Staging API is running."}
    return status_payload

@app.post("/furnish-room/")
async def furnish_room(request: StagingRequest):
    """Virtually stage a room photo and return the result as base64 PNG.

    The input image is decoded, converted to RGB and resized to 512x512. A
    mask and a depth control image are derived from it, and the ControlNet
    inpainting pipeline is run with the request's prompt, negative prompt
    and seed.

    Args:
        request: Staging parameters, including the base64-encoded image.

    Returns:
        ``{"result_image_base64": <base64-encoded PNG>}``.

    Raises:
        HTTPException: 400 when the image cannot be decoded, 500 for any
            other failure during processing.
    """
    try:
        try:
            image_bytes = base64.b64decode(request.image_base64)
            init_image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB").resize((512, 512))
        except (ValueError, OSError) as e:
            # binascii.Error is a ValueError; PIL's unreadable-image errors are
            # OSErrors. Both mean the client sent bad data, not a server fault.
            raise HTTPException(status_code=400, detail=f"Invalid input image: {e}") from e

        mask_image_pil = create_precise_mask(init_image_pil)
        control_image_pil = generate_depth_map(init_image_pil)

        inpainting_pipe = models['inpainting_pipe']
        # Create the generator on the pipeline's own device so CPU-only hosts
        # (which the startup code supports) do not fail on a hard-coded "cuda".
        generator = torch.Generator(device=inpainting_pipe.device).manual_seed(request.seed)
        # NOTE(review): this call is synchronous and runs inside an async
        # handler, so it presumably blocks the event loop while generating —
        # confirm whether serialised requests are the intended behaviour.
        final_image = inpainting_pipe(
            prompt=request.prompt, negative_prompt=request.negative_prompt, image=init_image_pil,
            mask_image=mask_image_pil, control_image=control_image_pil,
            num_inference_steps=30, guidance_scale=8.0, generator=generator,
        ).images[0]

        buffered = io.BytesIO()
        final_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {"result_image_base64": img_str}
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) from being rewrapped as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e