taaaranis commited on
Commit
5797b3f
·
verified ·
1 Parent(s): 4f3033e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -66
app.py CHANGED
@@ -1,27 +1,51 @@
1
  import gradio as gr
2
  import numpy as np
3
  import cv2
4
- import os, json, re, base64, requests
5
  from typing import List, Dict
6
  from huggingface_hub import hf_hub_download, list_repo_files
7
  from ultralytics import YOLO
 
 
 
 
 
 
 
 
8
 
9
  # ---------- Config ----------
10
  FATHOM_REPO = os.getenv("FathomNet/fathomnet2023-comp-baseline", "FathomNet/fathomnet2023-comp-baseline")
11
  FATHOM_PREF = ["fathomnet23-comp-baseline.pt", "best.pt", "yolov8m.pt"]
12
-
13
- # 触发大模型兜底的阈值(根据你的要求,保持 0.80 不变)
14
  CONF_LOW = float(os.getenv("CONF_LOW", "0.8"))
15
 
16
- # LLM(通过 Hugging Face Inference API 调用)
17
- HF_TOKEN = os.getenv("HF_TOKEN") # 必须:在 Space Secrets 里配置
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- # 【最终修改】更换为在免费API上稳定可靠的VQA模型,解决404问题
20
- LLM_MODEL_ID = (os.getenv("LLM_MODEL_ID", "").strip()
21
- or "dandelin/vilt-b32-finetuned-vqa")
22
 
23
  # ---------- Utils ----------
24
  def _resolve_weight(repo_id: str, prefer: List[str]) -> str:
 
25
  for fname in prefer:
26
  try:
27
  return hf_hub_download(repo_id=repo_id, filename=fname)
@@ -33,6 +57,7 @@ def _resolve_weight(repo_id: str, prefer: List[str]) -> str:
33
  raise RuntimeError(f"No .pt weights found in repo: {repo_id}")
34
 
35
  def _resize_limit_max_side(img_bgr: np.ndarray, max_side: int = 1280) -> np.ndarray:
 
36
  h, w = img_bgr.shape[:2]
37
  m = max(h, w)
38
  if m <= max_side:
@@ -41,7 +66,7 @@ def _resize_limit_max_side(img_bgr: np.ndarray, max_side: int = 1280) -> np.ndar
41
  return cv2.resize(img_bgr, (int(w*scale), int(h*scale)), interpolation=cv2.INTER_AREA)
42
 
43
  def uw_preprocess_bgr(img_bgr: np.ndarray) -> np.ndarray:
44
- # 轻量水下增强(清晰照片可关闭)
45
  lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB)
46
  l, a, b = cv2.split(lab); l = cv2.equalizeHist(l)
47
  img_bgr = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2BGR)
@@ -54,6 +79,7 @@ def uw_preprocess_bgr(img_bgr: np.ndarray) -> np.ndarray:
54
  return np.uint8(np.clip(np.dstack([bch, gch, rch]) * t, 0, 255))
55
 
56
  def _parse_yolo_result(ultra_res, names: Dict[int, str]):
 
57
  dets = []
58
  for b in ultra_res.boxes:
59
  cls_id = int(b.cls.item())
@@ -64,67 +90,39 @@ def _parse_yolo_result(ultra_res, names: Dict[int, str]):
64
  return dets
65
 
66
  def _best_box(dets):
 
67
  if not dets: return None
68
  return max(dets, key=lambda d: d["conf"])
69
 
70
  def _crop_xyxy(img: np.ndarray, box, pad: int = 4) -> np.ndarray:
 
71
  h, w = img.shape[:2]
72
  x1, y1, x2, y2 = [int(round(v)) for v in box]
73
  x1 = max(0, x1 - pad); y1 = max(0, y1 - pad)
74
  x2 = min(w - 1, x2 + pad); y2 = min(h - 1, y2 + pad)
75
  return img[y1:y2, x1:x2, :]
76
 
77
- def _jpeg_bytes_from_bgr(bgr: np.ndarray, max_side: int = 768) -> bytes:
78
- bgr = _resize_limit_max_side(bgr, max_side=max_side)
79
- ok, buf = cv2.imencode(".jpg", bgr, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
80
- return buf.tobytes() if ok else b""
81
-
82
- # ---------- Load primary detector ----------
83
- print("[init] loading FathomNet baseline ...")
84
- FATHOM_W = _resolve_weight(FATHOM_REPO, FATHOM_PREF)
85
- FATHOM = YOLO(FATHOM_W)
86
 
87
- # ---------- LLM fallback (VQA model via Inference API) ----------
88
-
89
- # 【最终修改】重写API调用函数,以适配稳定VQA模型的接口和返回格式
90
- def _call_vision_model_api(model_id: str, img_bytes: bytes, question: str) -> Dict:
91
- """调用Hugging Face上的视觉问答(VQA)模型"""
92
- url = f"https://api-inference.huggingface.co/models/{model_id}"
93
- img_b64 = base64.b64encode(img_bytes).decode("utf-8")
94
- headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
95
-
96
- # VQA模型使用不同的请求体结构
97
- payload = {"inputs": {"question": question, "image": img_b64}}
98
-
99
- r = requests.post(url, headers=headers, json=payload, timeout=60)
100
- r.raise_for_status()
101
-
102
- # VQA模型的返回格式: [{"score": 0.99, "answer": "..."}, ...]
103
- results = r.json()
104
- if results and isinstance(results, list) and results[0]:
105
- top_result = results[0]
106
- return {"label": top_result.get("answer", "vqa_parse_error"),
107
- "conf": top_result.get("score", 0.5)}
108
- else:
109
- return {"label": "vqa_empty_result", "conf": 0.5}
110
-
111
- # 【最终修改】重写LLM备用逻辑,使其调用新的VQA函数
112
  def llm_fallback(img_bgr: np.ndarray) -> Dict:
113
- """当主模型信心不足时,调用VQA模型进行识别。"""
114
- if not HF_TOKEN:
115
- return {"label": "unknown", "conf": 0.51, "xyxy": None, "note": "HF_TOKEN not set"}
116
-
117
- jb = _jpeg_bytes_from_bgr(img_bgr, max_side=768)
118
- # 构造一个适合VQA模型的“问题”
119
- prompt_text = "What is the single, most prominent marine species in this image?"
120
-
121
  try:
122
- result = _call_vision_model_api(LLM_MODEL_ID, jb, prompt_text)
123
- result["xyxy"] = None # VQA模型本身不提供边界框坐标
124
- return result
 
 
 
 
 
 
 
 
 
125
  except Exception as e:
126
- error_note = str(e).replace(HF_TOKEN, "HF_TOKEN_***") if HF_TOKEN else str(e)
127
- return {"label": "unknown", "conf": 0.51, "xyxy": None, "note": f"LLM error: {error_note}"}
128
 
129
 
130
  # ---------- Inference ----------
@@ -151,14 +149,12 @@ def predict(
151
  use_llm = (len(dets) == 0) or (max([d["conf"] for d in dets]) < CONF_LOW if dets else True)
152
 
153
  if use_llm:
154
- print("[info] Low confidence or no detection, triggering LLM fallback...")
155
  roi_img = bgr
156
  best = _best_box(dets)
157
  if best is not None:
158
  roi_img = _crop_xyxy(bgr, best["xyxy"])
159
-
160
  llm_det = llm_fallback(roi_img)
161
-
162
  if best is not None and llm_det.get("xyxy") is None:
163
  llm_det["xyxy"] = best["xyxy"]
164
  enhanced = [llm_det]
@@ -172,7 +168,6 @@ def predict(
172
  score = d["conf"]
173
  xyxy = d.get("xyxy")
174
  note = d.get("note")
175
-
176
  if xyxy and isinstance(xyxy, list) and len(xyxy) == 4:
177
  x1, y1, x2, y2 = map(int, xyxy)
178
  color = (0, 255, 0) if not use_llm else (255, 165, 0) # LLM用橙色框
@@ -180,7 +175,7 @@ def predict(
180
  cv2.putText(vis, f"{label_show} {score:.2f}", (x1, max(12, y1-6)),
181
  cv2.FONT_HERSHEY_SIMPLEX, 0.55, color, 2)
182
  else:
183
- text_to_show = f"{label_show} {score:.2f} (LLM)"
184
  if note:
185
  text_to_show = note
186
  cv2.putText(vis, text_to_show, (12, 24),
@@ -198,11 +193,11 @@ def predict(
198
  return None, [{"error": str(e)}]
199
 
200
  # ---------- Gradio UI ----------
201
- with gr.Blocks(title="Marine Species ID – YOLO primary + VQA fallback") as demo:
202
  gr.Markdown(
203
- "### Marine Species Identification (no training)\n"
204
  "- **Primary**: FathomNet 2023 Baseline (YOLOv8m)\n"
205
- f"- **Fallback**: Stable VQA Model via Hugging Face Inference API (triggered when max conf < {CONF_LOW} or no boxes)"
206
  )
207
  with gr.Row():
208
  with gr.Column(scale=5):
@@ -216,8 +211,9 @@ with gr.Blocks(title="Marine Species ID – YOLO primary + VQA fallback") as dem
216
  img_out = gr.Image(label="Detections", interactive=False)
217
  json_out = gr.JSON(label="Detections JSON (label/conf/xyxy)")
218
  btn.click(predict, inputs=[img_in, conf, iou, imgsz, pre], outputs=[img_out, json_out])
219
- gr.Markdown(f"Tip: add `HF_TOKEN` in Settings Repository secrets. "
220
- f"Optional envs: `CONF_LOW={CONF_LOW}`, `LLM_MODEL_ID={LLM_MODEL_ID}`.")
221
 
222
  if __name__ == "__main__":
223
  demo.launch()
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import cv2
4
+ import os, json, re, base64
5
  from typing import List, Dict
6
  from huggingface_hub import hf_hub_download, list_repo_files
7
  from ultralytics import YOLO
8
+ from PIL import Image
9
+ import torch
10
+
11
+ # 【重要】请确保你的Hugging Face Space的requirements.txt文件里包含以下库:
12
+ # transformers
13
+ # torch
14
+ # sentencepiece
15
+ # Pillow
16
 
17
  # ---------- Config ----------
18
  FATHOM_REPO = os.getenv("FathomNet/fathomnet2023-comp-baseline", "FathomNet/fathomnet2023-comp-baseline")
19
  FATHOM_PREF = ["fathomnet23-comp-baseline.pt", "best.pt", "yolov8m.pt"]
 
 
20
  CONF_LOW = float(os.getenv("CONF_LOW", "0.8"))
21
 
22
+ # ---------- Load primary detector ----------
23
+ print("[init] loading FathomNet baseline (YOLO)...")
24
+ FATHOM_W = _resolve_weight(FATHOM_REPO, FATHOM_PREF)
25
+ FATHOM = YOLO(FATHOM_W)
26
+
27
+ # ----------【最终修改】加载一个本地的、可靠的备用模型,不再依赖不稳定的外部API ----------
28
+ print("[init] loading fallback vision model (BLIP)...")
29
+ try:
30
+ from transformers import BlipProcessor, BlipForConditionalGeneration
31
+ # 检查GPU是否可用
32
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
33
+ print(f"[init] Fallback model will run on: {DEVICE}")
34
+
35
+ # 加载模型和处理器
36
+ fallback_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
37
+ fallback_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(DEVICE)
38
+ FALLBACK_MODEL_LOADED = True
39
+ except ImportError:
40
+ print("[warn] transformers, torch, or pillow not installed. Fallback model will not be available.")
41
+ print("[warn] Please add 'transformers', 'torch', 'sentencepiece', 'Pillow' to your requirements.txt")
42
+ FALLBACK_MODEL_LOADED = False
43
+ DEVICE = "cpu"
44
 
 
 
 
45
 
46
  # ---------- Utils ----------
47
  def _resolve_weight(repo_id: str, prefer: List[str]) -> str:
48
+ # ... (此函数无需修改)
49
  for fname in prefer:
50
  try:
51
  return hf_hub_download(repo_id=repo_id, filename=fname)
 
57
  raise RuntimeError(f"No .pt weights found in repo: {repo_id}")
58
 
59
  def _resize_limit_max_side(img_bgr: np.ndarray, max_side: int = 1280) -> np.ndarray:
60
+ # ... (此函数无需修改)
61
  h, w = img_bgr.shape[:2]
62
  m = max(h, w)
63
  if m <= max_side:
 
66
  return cv2.resize(img_bgr, (int(w*scale), int(h*scale)), interpolation=cv2.INTER_AREA)
67
 
68
  def uw_preprocess_bgr(img_bgr: np.ndarray) -> np.ndarray:
69
+ # ... (此函数无需修改)
70
  lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB)
71
  l, a, b = cv2.split(lab); l = cv2.equalizeHist(l)
72
  img_bgr = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2BGR)
 
79
  return np.uint8(np.clip(np.dstack([bch, gch, rch]) * t, 0, 255))
80
 
81
  def _parse_yolo_result(ultra_res, names: Dict[int, str]):
82
+ # ... (此函数无需修改)
83
  dets = []
84
  for b in ultra_res.boxes:
85
  cls_id = int(b.cls.item())
 
90
  return dets
91
 
92
  def _best_box(dets):
93
+ # ... (此函数无需修改)
94
  if not dets: return None
95
  return max(dets, key=lambda d: d["conf"])
96
 
97
  def _crop_xyxy(img: np.ndarray, box, pad: int = 4) -> np.ndarray:
98
+ # ... (此函数无需修改)
99
  h, w = img.shape[:2]
100
  x1, y1, x2, y2 = [int(round(v)) for v in box]
101
  x1 = max(0, x1 - pad); y1 = max(0, y1 - pad)
102
  x2 = min(w - 1, x2 + pad); y2 = min(h - 1, y2 + pad)
103
  return img[y1:y2, x1:x2, :]
104
 
 
 
 
 
 
 
 
 
 
105
 
106
+ # 【最终修改】重写LLM备用逻辑,使其调用加载到本地的BLIP模型
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  def llm_fallback(img_bgr: np.ndarray) -> Dict:
108
+ """当主模型信心不足时,调用本地BLIP模型生成描述。"""
109
+ if not FALLBACK_MODEL_LOADED:
110
+ return {"label": "unknown", "conf": 0.51, "xyxy": None, "note": "Fallback model not loaded"}
 
 
 
 
 
111
  try:
112
+ # 1. 将图像从OpenCV格式(BGR)转换为PIL格式(RGB)
113
+ raw_image = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
114
+ # 2. 为模型准备输入
115
+ inputs = fallback_processor(raw_image, return_tensors="pt").to(DEVICE)
116
+ # 3. 生成描述
117
+ out = fallback_model.generate(**inputs, max_new_tokens=20)
118
+ # 4. 解码并清理描述文本
119
+ caption = fallback_processor.decode(out[0], skip_special_tokens=True)
120
+ # 这是一个简单的启发式清理,可能需要根据实际情况微调
121
+ caption = caption.replace("a photograph of", "").replace("a close up of", "").strip()
122
+ # BLIP模型不提供置信度分数,我们给一个固定的值以表明这是备用模型的结果
123
+ return {"label": caption, "conf": 0.60, "xyxy": None}
124
  except Exception as e:
125
+ return {"label": "unknown", "conf": 0.51, "xyxy": None, "note": f"Fallback model error: {e}"}
 
126
 
127
 
128
  # ---------- Inference ----------
 
149
  use_llm = (len(dets) == 0) or (max([d["conf"] for d in dets]) < CONF_LOW if dets else True)
150
 
151
  if use_llm:
152
+ print("[info] Low confidence or no detection, triggering LOCAL fallback model...")
153
  roi_img = bgr
154
  best = _best_box(dets)
155
  if best is not None:
156
  roi_img = _crop_xyxy(bgr, best["xyxy"])
 
157
  llm_det = llm_fallback(roi_img)
 
158
  if best is not None and llm_det.get("xyxy") is None:
159
  llm_det["xyxy"] = best["xyxy"]
160
  enhanced = [llm_det]
 
168
  score = d["conf"]
169
  xyxy = d.get("xyxy")
170
  note = d.get("note")
 
171
  if xyxy and isinstance(xyxy, list) and len(xyxy) == 4:
172
  x1, y1, x2, y2 = map(int, xyxy)
173
  color = (0, 255, 0) if not use_llm else (255, 165, 0) # LLM用橙色框
 
175
  cv2.putText(vis, f"{label_show} {score:.2f}", (x1, max(12, y1-6)),
176
  cv2.FONT_HERSHEY_SIMPLEX, 0.55, color, 2)
177
  else:
178
+ text_to_show = f"{label_show} {score:.2f} (Fallback Model)"
179
  if note:
180
  text_to_show = note
181
  cv2.putText(vis, text_to_show, (12, 24),
 
193
  return None, [{"error": str(e)}]
194
 
195
  # ---------- Gradio UI ----------
196
+ with gr.Blocks(title="Marine Species ID – YOLO primary + Local Fallback") as demo:
197
  gr.Markdown(
198
+ "### Marine Species Identification (with Self-Contained Fallback)\n"
199
  "- **Primary**: FathomNet 2023 Baseline (YOLOv8m)\n"
200
+ f"- **Fallback**: Local BLIP Model (triggered when max conf < {CONF_LOW} or no boxes)"
201
  )
202
  with gr.Row():
203
  with gr.Column(scale=5):
 
211
  img_out = gr.Image(label="Detections", interactive=False)
212
  json_out = gr.JSON(label="Detections JSON (label/conf/xyxy)")
213
  btn.click(predict, inputs=[img_in, conf, iou, imgsz, pre], outputs=[img_out, json_out])
214
+ gr.Markdown("Tip: This app is now self-contained and does not require an HF_TOKEN. "
215
+ "The first launch may be slow due to model downloads.")
216
 
217
  if __name__ == "__main__":
218
  demo.launch()
219
+