Update app.py
app.py
@@ -18,7 +18,6 @@ import torchvision
 from torchvision.transforms.functional import to_pil_image
 from huggingface_hub import hf_hub_download
 
-import spaces
 import gradio as gr
 
 from transformers import SamModel, SamProcessor
@@ -26,6 +25,14 @@ from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
 from sam2 import VQ_SAM2, VQ_SAM2Config, SAM2Config
 from visualizer import sample_color, draw_mask
 
+# Set the device to use GPU and ensure CUDA is available
+DEVICE = "cuda"  # dedicated GPU runtime
+if not torch.cuda.is_available():
+    raise RuntimeError(
+        "CUDA is not available. Run the container with GPU access (e.g. --gpus all) "
+        "and ensure NVIDIA drivers + container runtime are installed."
+    )
+
 class DirectResize:
     def __init__(self, target_length: int) -> None:
         self.target_length = target_length
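
Viewed on its own, the guard added in this hunk trades a lazy failure (the first CUDA call erroring mid-request) for a loud failure at import time. A minimal standalone sketch of the same fail-fast pattern; the APP_DEVICE environment override is an illustrative assumption, not something this app defines:

```python
# Fail-fast device guard, mirroring the hunk above.
# Assumption: the APP_DEVICE override is hypothetical; the app hard-codes "cuda".
import os
import torch

DEVICE = os.environ.get("APP_DEVICE", "cuda")
if DEVICE.startswith("cuda") and not torch.cuda.is_available():
    # Dying at import time beats a CUDA error in the middle of a user request.
    raise RuntimeError("CUDA is not available; start the container with GPU access.")

print(torch.zeros(2, device=DEVICE).device)  # everything downstream uses DEVICE
```
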
@@ -110,9 +117,7 @@ def load_vq_sam2():
 
     if vq_sam2 is not None:
         return vq_sam2
-
-    if hasattr(torch, "set_default_device"):
-        torch.set_default_device("cpu")
+
 
     sam2_config = SAM2Config(
         ckpt_path=sam2_ckpt_local,
@@ -129,7 +134,7 @@ def load_vq_sam2():
     state = torch.load(mask_tokenizer_local, map_location="cpu")
     vq_sam2.load_state_dict(state)
 
-    vq_sam2 = vq_sam2.
+    vq_sam2 = vq_sam2.to(DEVICE).eval()
     return vq_sam2
 
 processor = AutoProcessor.from_pretrained(MODEL)
@@ -139,17 +144,18 @@ _qwen = None
 _sam = None
 
 def get_qwen():
-    """Must be called only inside @spaces.GPU function."""
     global _qwen
     if _qwen is None:
-        _qwen = Qwen3VLForConditionalGeneration.from_pretrained(
+        _qwen = Qwen3VLForConditionalGeneration.from_pretrained(
+            MODEL,
+            torch_dtype="auto",
+        ).to(DEVICE).eval()
     return _qwen
 
 def get_sam():
-    """Must be called only inside @spaces.GPU function."""
     global _sam
     if _sam is None:
-        _sam = SamModel.from_pretrained("facebook/sam-vit-huge").to(
+        _sam = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE).eval()
    return _sam
 
 colors = sample_color()
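
Both loaders keep the lazy-singleton shape: the first call pays the checkpoint load, every later call returns the cached module. A runnable sketch of the pattern, with an `nn.Linear` standing in for the Qwen3-VL and SAM checkpoints:

```python
# Lazy-singleton loader, the shape get_qwen()/get_sam() share.
# Assumption: nn.Linear stands in for the real from_pretrained(...) checkpoints.
import torch

DEVICE = "cpu"  # the app uses "cuda"; "cpu" keeps the sketch runnable anywhere
_model = None   # module-level cache, like _qwen and _sam

def get_model() -> torch.nn.Module:
    global _model
    if _model is None:  # first call loads; later calls reuse
        _model = torch.nn.Linear(4, 4).to(DEVICE).eval()
    return _model

assert get_model() is get_model()  # one shared instance per process
```
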
@@ -185,7 +191,6 @@ def new_mu_state():
         "next_region_id": 1,
     }
 
-@spaces.GPU
 def mu_on_upload_image(media_path, mu_state):
     if not media_path:
         return new_mu_state(), None, None
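
This hunk is the first of six identical deletions (they repeat down through infer_seg): every event handler loses its @spaces.GPU decorator. On ZeroGPU Spaces that decorator attaches a GPU to the process only for the duration of the decorated call; on a dedicated GPU runtime it is unnecessary, and the import fails entirely outside Spaces. A hedged sketch of the two variants; the `spaces` package only resolves on Hugging Face Spaces:

```python
# Why "-@spaces.GPU" repeats below: ZeroGPU vs. dedicated-GPU handlers.
# Assumption: sketch only; `spaces` is installed on Hugging Face Spaces, not locally.
import spaces

@spaces.GPU  # ZeroGPU: a GPU is attached for the duration of this call
def handler_zerogpu(media_path):
    ...

def handler_dedicated(media_path):  # dedicated runtime: CUDA is always present
    ...
```
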
@@ -195,9 +200,9 @@ def mu_on_upload_image(media_path, mu_state):
     img = Image.open(media_path).convert("RGB")
     w, h = img.size
 
-    inputs = sam_processor(img, return_tensors="pt").to(
+    inputs = sam_processor(img, return_tensors="pt").to(DEVICE)
     with torch.no_grad():
-        emb = sam_model.get_image_embeddings(inputs["pixel_values"]) #
+        emb = sam_model.get_image_embeddings(inputs["pixel_values"])  # tensor on DEVICE
 
     st = new_mu_state()
     st["image_path"] = media_path
@@ -227,10 +232,10 @@ def mu_predict_mask_from_state(mu_state):
         input_points=[mu_state["points"]],
         input_labels=[mu_state["labels"]],
         return_tensors="pt",
-    ).to(
+    ).to(DEVICE)
 
     # restore embedding to CUDA tensor, shape (1,256,64,64)
-    emb = torch.from_numpy(mu_state["image_embeddings"]).to(
+    emb = torch.from_numpy(mu_state["image_embeddings"]).to(DEVICE)
     emb = emb.unsqueeze(0)
 
     with torch.no_grad():
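
Together with the upload handler above, this hunk closes a round-trip: the SAM image embedding is computed once per upload, parked in the Gradio session state as a NumPy array, and promoted back to a DEVICE tensor on every click. A sketch with random stand-in data (per the comment in the hunk, the restored embedding is 1 x 256 x 64 x 64):

```python
# GPU tensor -> NumPy in session state -> GPU tensor on reuse (stand-in data).
import torch

DEVICE = "cpu"  # the app uses "cuda"; "cpu" keeps the sketch runnable anywhere
emb_gpu = torch.randn(256, 64, 64, device=DEVICE)    # stand-in SAM embedding
state = {"image_embeddings": emb_gpu.cpu().numpy()}  # stored once per image

# later, on each click:
emb = torch.from_numpy(state["image_embeddings"]).to(DEVICE)
emb = emb.unsqueeze(0)  # (1, 256, 64, 64), matching the comment in the hunk
assert emb.shape == (1, 256, 64, 64)
```
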
@@ -253,7 +258,6 @@ def mu_predict_mask_from_state(mu_state):
     mask = (mask > 0).astype(np.float32)
     return mask
 
-@spaces.GPU
 def mu_add_point(evt: gr.SelectData, mu_state, is_positive: bool):
     if mu_state["image_path"] is None:
         return mu_state, None
@@ -266,7 +270,6 @@ def mu_add_point(evt: gr.SelectData, mu_state, is_positive: bool):
     mu_state["cur_mask"] = mask
     return mu_state, mask
 
-@spaces.GPU
 def mu_add_point_xy(xy, mu_state, is_positive: bool):
     if mu_state["image_path"] is None:
         return mu_state, None
@@ -293,7 +296,6 @@ def mu_clear_prompts(mu_state):
     mu_state["cur_mask"] = None
     return mu_state, None
 
-@spaces.GPU
 def mu_save_region(mu_state):
     if mu_state["cur_mask"] is None:
         return mu_state, gr.update(choices=[], value=None)
@@ -439,7 +441,6 @@ def replace_region_all(text: str, rid: str, token_str: str) -> str:
 def short_tag_from_codes(code_a: int, code_b: int) -> str:
     return f"<{code_a:04d}-{code_b:04d}>"
 
-@spaces.GPU
 def infer_understanding(mu_media, mu_query, mu_state):
     model = get_qwen()
 
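
For reference, the context lines above show the tag format the region tokens use: both codebook indices are zero-padded to four digits.

```python
def short_tag_from_codes(code_a: int, code_b: int) -> str:
    return f"<{code_a:04d}-{code_b:04d}>"

print(short_tag_from_codes(7, 123))    # <0007-0123>
print(short_tag_from_codes(42, 4095))  # <0042-4095>
```
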
@@ -500,7 +501,6 @@ def infer_understanding(mu_media, mu_query, mu_state):
         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
     )[0]
 
-@spaces.GPU
 def infer_seg(media, query):
     model = get_qwen()
     vq_sam2 = load_vq_sam2()
@@ -642,7 +642,7 @@ def build_demo():
     gr.HTML(HEADER)
 
     with gr.Tab('Mask Generation'):
-        download_btn_1 = gr.DownloadButton(label='
+        download_btn_1 = gr.DownloadButton(label='📦 Download', interactive=False, render=False)
         msk_1 = gr.AnnotatedImage(label='De-tokenized 2D masks', color_map=color_map, render=False)
         ans_1 = gr.HighlightedText(
             label='Model Response', color_map=color_map_light, show_inline_category=False, render=False)
@@ -661,14 +661,14 @@ def build_demo():
         )
 
         with gr.Row():
-            random_btn_1 = gr.Button(value='
+            random_btn_1 = gr.Button(value='🔮 Random', visible=False)
 
-            reset_btn_1 = gr.ClearButton([media_1, query_1, msk_1, ans_1], value='
+            reset_btn_1 = gr.ClearButton([media_1, query_1, msk_1, ans_1], value='🗑️ Reset')
             reset_btn_1.click(reset_seg, None, [sample_frames_1, download_btn_1])
 
             download_btn_1.render()
 
-            submit_btn_1 = gr.Button(value='
+            submit_btn_1 = gr.Button(value='🚀 Submit', variant='primary', elem_id='submit_1')
 
         with gr.Column():
             msk_1.render()
@@ -694,7 +694,7 @@ def build_demo():
     )
     with gr.Tab("Mask Understanding"):
         MU_INSTRUCTIONS = """
-### Mask Understanding
+### Mask Understanding — Instructions
 
 1. **Upload an image.**
 2. **Create a region mask**
@@ -789,7 +789,12 @@ def build_demo():
     return demo
 
 if __name__ == '__main__':
-    demo = build_demo()
+    # Warm-up: load all heavy models once at startup (dedicated GPU server)
+    get_qwen()
+    get_sam()
+    load_vq_sam2()
 
+    demo = build_demo()
     demo.queue()
-    demo.launch()
+    port = int(os.getenv("PORT", "7860"))
+    demo.launch(server_name="0.0.0.0", server_port=port)