CultriX committed on
Commit
8304eb6
·
verified ·
1 Parent(s): 80f498e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -26
app.py CHANGED
@@ -18,7 +18,6 @@ import torchvision
18
  from torchvision.transforms.functional import to_pil_image
19
  from huggingface_hub import hf_hub_download
20
 
21
- import spaces
22
  import gradio as gr
23
 
24
  from transformers import SamModel, SamProcessor
@@ -26,6 +25,14 @@ from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
26
  from sam2 import VQ_SAM2, VQ_SAM2Config, SAM2Config
27
  from visualizer import sample_color, draw_mask
28
 
 
 
 
 
 
 
 
 
29
  class DirectResize:
30
  def __init__(self, target_length: int) -> None:
31
  self.target_length = target_length
@@ -110,9 +117,7 @@ def load_vq_sam2():
110
 
111
  if vq_sam2 is not None:
112
  return vq_sam2
113
-
114
- if hasattr(torch, "set_default_device"):
115
- torch.set_default_device("cpu")
116
 
117
  sam2_config = SAM2Config(
118
  ckpt_path=sam2_ckpt_local,
@@ -129,7 +134,7 @@ def load_vq_sam2():
129
  state = torch.load(mask_tokenizer_local, map_location="cpu")
130
  vq_sam2.load_state_dict(state)
131
 
132
- vq_sam2 = vq_sam2.cuda().eval()
133
  return vq_sam2
134
 
135
  processor = AutoProcessor.from_pretrained(MODEL)
@@ -139,17 +144,18 @@ _qwen = None
139
  _sam = None
140
 
141
  def get_qwen():
142
- """Must be called only inside @spaces.GPU function."""
143
  global _qwen
144
  if _qwen is None:
145
- _qwen = Qwen3VLForConditionalGeneration.from_pretrained(MODEL, torch_dtype="auto").to("cuda").eval()
 
 
 
146
  return _qwen
147
 
148
  def get_sam():
149
- """Must be called only inside @spaces.GPU function."""
150
  global _sam
151
  if _sam is None:
152
- _sam = SamModel.from_pretrained("facebook/sam-vit-huge").to("cuda").eval()
153
  return _sam
154
 
155
  colors = sample_color()
@@ -185,7 +191,6 @@ def new_mu_state():
185
  "next_region_id": 1,
186
  }
187
 
188
- @spaces.GPU
189
  def mu_on_upload_image(media_path, mu_state):
190
  if not media_path:
191
  return new_mu_state(), None, None
@@ -195,9 +200,9 @@ def mu_on_upload_image(media_path, mu_state):
195
  img = Image.open(media_path).convert("RGB")
196
  w, h = img.size
197
 
198
- inputs = sam_processor(img, return_tensors="pt").to("cuda")
199
  with torch.no_grad():
200
- emb = sam_model.get_image_embeddings(inputs["pixel_values"]) # CUDA tensor
201
 
202
  st = new_mu_state()
203
  st["image_path"] = media_path
@@ -227,10 +232,10 @@ def mu_predict_mask_from_state(mu_state):
227
  input_points=[mu_state["points"]],
228
  input_labels=[mu_state["labels"]],
229
  return_tensors="pt",
230
- ).to("cuda")
231
 
232
  # restore embedding to CUDA tensor, shape (1,256,64,64)
233
- emb = torch.from_numpy(mu_state["image_embeddings"]).to("cuda")
234
  emb = emb.unsqueeze(0)
235
 
236
  with torch.no_grad():
@@ -253,7 +258,6 @@ def mu_predict_mask_from_state(mu_state):
253
  mask = (mask > 0).astype(np.float32)
254
  return mask
255
 
256
- @spaces.GPU
257
  def mu_add_point(evt: gr.SelectData, mu_state, is_positive: bool):
258
  if mu_state["image_path"] is None:
259
  return mu_state, None
@@ -266,7 +270,6 @@ def mu_add_point(evt: gr.SelectData, mu_state, is_positive: bool):
266
  mu_state["cur_mask"] = mask
267
  return mu_state, mask
268
 
269
- @spaces.GPU
270
  def mu_add_point_xy(xy, mu_state, is_positive: bool):
271
  if mu_state["image_path"] is None:
272
  return mu_state, None
@@ -293,7 +296,6 @@ def mu_clear_prompts(mu_state):
293
  mu_state["cur_mask"] = None
294
  return mu_state, None
295
 
296
- @spaces.GPU
297
  def mu_save_region(mu_state):
298
  if mu_state["cur_mask"] is None:
299
  return mu_state, gr.update(choices=[], value=None)
@@ -439,7 +441,6 @@ def replace_region_all(text: str, rid: str, token_str: str) -> str:
439
  def short_tag_from_codes(code_a: int, code_b: int) -> str:
440
  return f"<{code_a:04d}-{code_b:04d}>"
441
 
442
- @spaces.GPU
443
  def infer_understanding(mu_media, mu_query, mu_state):
444
  model = get_qwen()
445
 
@@ -500,7 +501,6 @@ def infer_understanding(mu_media, mu_query, mu_state):
500
  generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
501
  )[0]
502
 
503
- @spaces.GPU
504
  def infer_seg(media, query):
505
  model = get_qwen()
506
  vq_sam2 = load_vq_sam2()
@@ -642,7 +642,7 @@ def build_demo():
642
  gr.HTML(HEADER)
643
 
644
  with gr.Tab('Mask Generation'):
645
- download_btn_1 = gr.DownloadButton(label='📦 Download', interactive=False, render=False)
646
  msk_1 = gr.AnnotatedImage(label='De-tokenized 2D masks', color_map=color_map, render=False)
647
  ans_1 = gr.HighlightedText(
648
  label='Model Response', color_map=color_map_light, show_inline_category=False, render=False)
@@ -661,14 +661,14 @@ def build_demo():
661
  )
662
 
663
  with gr.Row():
664
- random_btn_1 = gr.Button(value='🔮 Random', visible=False)
665
 
666
- reset_btn_1 = gr.ClearButton([media_1, query_1, msk_1, ans_1], value='🗑️ Reset')
667
  reset_btn_1.click(reset_seg, None, [sample_frames_1, download_btn_1])
668
 
669
  download_btn_1.render()
670
 
671
- submit_btn_1 = gr.Button(value='🚀 Submit', variant='primary', elem_id='submit_1')
672
 
673
  with gr.Column():
674
  msk_1.render()
@@ -694,7 +694,7 @@ def build_demo():
694
  )
695
  with gr.Tab("Mask Understanding"):
696
  MU_INSTRUCTIONS = """
697
- ### Mask Understanding Instructions
698
 
699
  1. **Upload an image.**
700
  2. **Create a region mask**
@@ -789,7 +789,12 @@ def build_demo():
789
  return demo
790
 
791
  if __name__ == '__main__':
792
- demo = build_demo()
 
 
 
793
 
 
794
  demo.queue()
795
- demo.launch()
 
 
18
  from torchvision.transforms.functional import to_pil_image
19
  from huggingface_hub import hf_hub_download
20
 
 
21
  import gradio as gr
22
 
23
  from transformers import SamModel, SamProcessor
 
25
  from sam2 import VQ_SAM2, VQ_SAM2Config, SAM2Config
26
  from visualizer import sample_color, draw_mask
27
 
28
+ # Set the device to use GPU and ensure CUDA is available
29
+ DEVICE = "cuda" # dedicated GPU runtime
30
+ if not torch.cuda.is_available():
31
+ raise RuntimeError(
32
+ "CUDA is not available. Run the container with GPU access (e.g. --gpus all) "
33
+ "and ensure NVIDIA drivers + container runtime are installed."
34
+ )
35
+
36
  class DirectResize:
37
  def __init__(self, target_length: int) -> None:
38
  self.target_length = target_length
 
117
 
118
  if vq_sam2 is not None:
119
  return vq_sam2
120
+
 
 
121
 
122
  sam2_config = SAM2Config(
123
  ckpt_path=sam2_ckpt_local,
 
134
  state = torch.load(mask_tokenizer_local, map_location="cpu")
135
  vq_sam2.load_state_dict(state)
136
 
137
+ vq_sam2 = vq_sam2.to(DEVICE).eval()
138
  return vq_sam2
139
 
140
  processor = AutoProcessor.from_pretrained(MODEL)
 
144
  _sam = None
145
 
146
  def get_qwen():
 
147
  global _qwen
148
  if _qwen is None:
149
+ _qwen = Qwen3VLForConditionalGeneration.from_pretrained(
150
+ MODEL,
151
+ torch_dtype="auto",
152
+ ).to(DEVICE).eval()
153
  return _qwen
154
 
155
  def get_sam():
 
156
  global _sam
157
  if _sam is None:
158
+ _sam = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE).eval()
159
  return _sam
160
 
161
  colors = sample_color()
 
191
  "next_region_id": 1,
192
  }
193
 
 
194
  def mu_on_upload_image(media_path, mu_state):
195
  if not media_path:
196
  return new_mu_state(), None, None
 
200
  img = Image.open(media_path).convert("RGB")
201
  w, h = img.size
202
 
203
+ inputs = sam_processor(img, return_tensors="pt").to(DEVICE)
204
  with torch.no_grad():
205
+ emb = sam_model.get_image_embeddings(inputs["pixel_values"]) # tensor on DEVICE
206
 
207
  st = new_mu_state()
208
  st["image_path"] = media_path
 
232
  input_points=[mu_state["points"]],
233
  input_labels=[mu_state["labels"]],
234
  return_tensors="pt",
235
+ ).to(DEVICE)
236
 
237
  # restore embedding to CUDA tensor, shape (1,256,64,64)
238
+ emb = torch.from_numpy(mu_state["image_embeddings"]).to(DEVICE)
239
  emb = emb.unsqueeze(0)
240
 
241
  with torch.no_grad():
 
258
  mask = (mask > 0).astype(np.float32)
259
  return mask
260
 
 
261
  def mu_add_point(evt: gr.SelectData, mu_state, is_positive: bool):
262
  if mu_state["image_path"] is None:
263
  return mu_state, None
 
270
  mu_state["cur_mask"] = mask
271
  return mu_state, mask
272
 
 
273
  def mu_add_point_xy(xy, mu_state, is_positive: bool):
274
  if mu_state["image_path"] is None:
275
  return mu_state, None
 
296
  mu_state["cur_mask"] = None
297
  return mu_state, None
298
 
 
299
  def mu_save_region(mu_state):
300
  if mu_state["cur_mask"] is None:
301
  return mu_state, gr.update(choices=[], value=None)
 
441
  def short_tag_from_codes(code_a: int, code_b: int) -> str:
442
  return f"<{code_a:04d}-{code_b:04d}>"
443
 
 
444
  def infer_understanding(mu_media, mu_query, mu_state):
445
  model = get_qwen()
446
 
 
501
  generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
502
  )[0]
503
 
 
504
  def infer_seg(media, query):
505
  model = get_qwen()
506
  vq_sam2 = load_vq_sam2()
 
642
  gr.HTML(HEADER)
643
 
644
  with gr.Tab('Mask Generation'):
645
+ download_btn_1 = gr.DownloadButton(label='📦 Download', interactive=False, render=False)
646
  msk_1 = gr.AnnotatedImage(label='De-tokenized 2D masks', color_map=color_map, render=False)
647
  ans_1 = gr.HighlightedText(
648
  label='Model Response', color_map=color_map_light, show_inline_category=False, render=False)
 
661
  )
662
 
663
  with gr.Row():
664
+ random_btn_1 = gr.Button(value='🔮 Random', visible=False)
665
 
666
+ reset_btn_1 = gr.ClearButton([media_1, query_1, msk_1, ans_1], value='🗑️ Reset')
667
  reset_btn_1.click(reset_seg, None, [sample_frames_1, download_btn_1])
668
 
669
  download_btn_1.render()
670
 
671
+ submit_btn_1 = gr.Button(value='🚀 Submit', variant='primary', elem_id='submit_1')
672
 
673
  with gr.Column():
674
  msk_1.render()
 
694
  )
695
  with gr.Tab("Mask Understanding"):
696
  MU_INSTRUCTIONS = """
697
+ ### Mask Understanding — Instructions
698
 
699
  1. **Upload an image.**
700
  2. **Create a region mask**
 
789
  return demo
790
 
791
  if __name__ == '__main__':
792
+ # Warm-up: load all heavy models once at startup (dedicated GPU server)
793
+ get_qwen()
794
+ get_sam()
795
+ load_vq_sam2()
796
 
797
+ demo = build_demo()
798
  demo.queue()
799
+ port = int(os.getenv("PORT", "7860"))
800
+ demo.launch(server_name="0.0.0.0", server_port=port)