Update app.py
Browse files
app.py
CHANGED
|
@@ -1,27 +1,51 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import numpy as np
|
| 3 |
import cv2
|
| 4 |
-
import os, json, re, base64
|
| 5 |
from typing import List, Dict
|
| 6 |
from huggingface_hub import hf_hub_download, list_repo_files
|
| 7 |
from ultralytics import YOLO
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
# ---------- Config ----------
|
| 10 |
FATHOM_REPO = os.getenv("FathomNet/fathomnet2023-comp-baseline", "FathomNet/fathomnet2023-comp-baseline")
|
| 11 |
FATHOM_PREF = ["fathomnet23-comp-baseline.pt", "best.pt", "yolov8m.pt"]
|
| 12 |
-
|
| 13 |
-
# 触发大模型兜底的阈值(根据你的要求,保持 0.80 不变)
|
| 14 |
CONF_LOW = float(os.getenv("CONF_LOW", "0.8"))
|
| 15 |
|
| 16 |
-
#
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
# 【最终修改】更换为在免费API上稳定可靠的VQA模型,解决404问题
|
| 20 |
-
LLM_MODEL_ID = (os.getenv("LLM_MODEL_ID", "").strip()
|
| 21 |
-
or "dandelin/vilt-b32-finetuned-vqa")
|
| 22 |
|
| 23 |
# ---------- Utils ----------
|
| 24 |
def _resolve_weight(repo_id: str, prefer: List[str]) -> str:
|
|
|
|
| 25 |
for fname in prefer:
|
| 26 |
try:
|
| 27 |
return hf_hub_download(repo_id=repo_id, filename=fname)
|
|
@@ -33,6 +57,7 @@ def _resolve_weight(repo_id: str, prefer: List[str]) -> str:
|
|
| 33 |
raise RuntimeError(f"No .pt weights found in repo: {repo_id}")
|
| 34 |
|
| 35 |
def _resize_limit_max_side(img_bgr: np.ndarray, max_side: int = 1280) -> np.ndarray:
|
|
|
|
| 36 |
h, w = img_bgr.shape[:2]
|
| 37 |
m = max(h, w)
|
| 38 |
if m <= max_side:
|
|
@@ -41,7 +66,7 @@ def _resize_limit_max_side(img_bgr: np.ndarray, max_side: int = 1280) -> np.ndar
|
|
| 41 |
return cv2.resize(img_bgr, (int(w*scale), int(h*scale)), interpolation=cv2.INTER_AREA)
|
| 42 |
|
| 43 |
def uw_preprocess_bgr(img_bgr: np.ndarray) -> np.ndarray:
|
| 44 |
-
#
|
| 45 |
lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB)
|
| 46 |
l, a, b = cv2.split(lab); l = cv2.equalizeHist(l)
|
| 47 |
img_bgr = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2BGR)
|
|
@@ -54,6 +79,7 @@ def uw_preprocess_bgr(img_bgr: np.ndarray) -> np.ndarray:
|
|
| 54 |
return np.uint8(np.clip(np.dstack([bch, gch, rch]) * t, 0, 255))
|
| 55 |
|
| 56 |
def _parse_yolo_result(ultra_res, names: Dict[int, str]):
|
|
|
|
| 57 |
dets = []
|
| 58 |
for b in ultra_res.boxes:
|
| 59 |
cls_id = int(b.cls.item())
|
|
@@ -64,67 +90,39 @@ def _parse_yolo_result(ultra_res, names: Dict[int, str]):
|
|
| 64 |
return dets
|
| 65 |
|
| 66 |
def _best_box(dets):
|
|
|
|
| 67 |
if not dets: return None
|
| 68 |
return max(dets, key=lambda d: d["conf"])
|
| 69 |
|
| 70 |
def _crop_xyxy(img: np.ndarray, box, pad: int = 4) -> np.ndarray:
|
|
|
|
| 71 |
h, w = img.shape[:2]
|
| 72 |
x1, y1, x2, y2 = [int(round(v)) for v in box]
|
| 73 |
x1 = max(0, x1 - pad); y1 = max(0, y1 - pad)
|
| 74 |
x2 = min(w - 1, x2 + pad); y2 = min(h - 1, y2 + pad)
|
| 75 |
return img[y1:y2, x1:x2, :]
|
| 76 |
|
| 77 |
-
def _jpeg_bytes_from_bgr(bgr: np.ndarray, max_side: int = 768) -> bytes:
|
| 78 |
-
bgr = _resize_limit_max_side(bgr, max_side=max_side)
|
| 79 |
-
ok, buf = cv2.imencode(".jpg", bgr, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
|
| 80 |
-
return buf.tobytes() if ok else b""
|
| 81 |
-
|
| 82 |
-
# ---------- Load primary detector ----------
|
| 83 |
-
print("[init] loading FathomNet baseline ...")
|
| 84 |
-
FATHOM_W = _resolve_weight(FATHOM_REPO, FATHOM_PREF)
|
| 85 |
-
FATHOM = YOLO(FATHOM_W)
|
| 86 |
|
| 87 |
-
#
|
| 88 |
-
|
| 89 |
-
# 【最终修改】重写API调用函数,以适配稳定VQA模型的接口和返回格式
|
| 90 |
-
def _call_vision_model_api(model_id: str, img_bytes: bytes, question: str) -> Dict:
|
| 91 |
-
"""调用Hugging Face上的视觉问答(VQA)模型"""
|
| 92 |
-
url = f"https://api-inference.huggingface.co/models/{model_id}"
|
| 93 |
-
img_b64 = base64.b64encode(img_bytes).decode("utf-8")
|
| 94 |
-
headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
|
| 95 |
-
|
| 96 |
-
# VQA模型使用不同的请求体结构
|
| 97 |
-
payload = {"inputs": {"question": question, "image": img_b64}}
|
| 98 |
-
|
| 99 |
-
r = requests.post(url, headers=headers, json=payload, timeout=60)
|
| 100 |
-
r.raise_for_status()
|
| 101 |
-
|
| 102 |
-
# VQA模型的返回格式: [{"score": 0.99, "answer": "..."}, ...]
|
| 103 |
-
results = r.json()
|
| 104 |
-
if results and isinstance(results, list) and results[0]:
|
| 105 |
-
top_result = results[0]
|
| 106 |
-
return {"label": top_result.get("answer", "vqa_parse_error"),
|
| 107 |
-
"conf": top_result.get("score", 0.5)}
|
| 108 |
-
else:
|
| 109 |
-
return {"label": "vqa_empty_result", "conf": 0.5}
|
| 110 |
-
|
| 111 |
-
# 【最终修改】重写LLM备用逻辑,使其调用新的VQA函数
|
| 112 |
def llm_fallback(img_bgr: np.ndarray) -> Dict:
|
| 113 |
-
"""当主模型信心不足时,调用
|
| 114 |
-
if not
|
| 115 |
-
return {"label": "unknown", "conf": 0.51, "xyxy": None, "note": "
|
| 116 |
-
|
| 117 |
-
jb = _jpeg_bytes_from_bgr(img_bgr, max_side=768)
|
| 118 |
-
# 构造一个适合VQA模型的“问题”
|
| 119 |
-
prompt_text = "What is the single, most prominent marine species in this image?"
|
| 120 |
-
|
| 121 |
try:
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
except Exception as e:
|
| 126 |
-
|
| 127 |
-
return {"label": "unknown", "conf": 0.51, "xyxy": None, "note": f"LLM error: {error_note}"}
|
| 128 |
|
| 129 |
|
| 130 |
# ---------- Inference ----------
|
|
@@ -151,14 +149,12 @@ def predict(
|
|
| 151 |
use_llm = (len(dets) == 0) or (max([d["conf"] for d in dets]) < CONF_LOW if dets else True)
|
| 152 |
|
| 153 |
if use_llm:
|
| 154 |
-
print("[info] Low confidence or no detection, triggering
|
| 155 |
roi_img = bgr
|
| 156 |
best = _best_box(dets)
|
| 157 |
if best is not None:
|
| 158 |
roi_img = _crop_xyxy(bgr, best["xyxy"])
|
| 159 |
-
|
| 160 |
llm_det = llm_fallback(roi_img)
|
| 161 |
-
|
| 162 |
if best is not None and llm_det.get("xyxy") is None:
|
| 163 |
llm_det["xyxy"] = best["xyxy"]
|
| 164 |
enhanced = [llm_det]
|
|
@@ -172,7 +168,6 @@ def predict(
|
|
| 172 |
score = d["conf"]
|
| 173 |
xyxy = d.get("xyxy")
|
| 174 |
note = d.get("note")
|
| 175 |
-
|
| 176 |
if xyxy and isinstance(xyxy, list) and len(xyxy) == 4:
|
| 177 |
x1, y1, x2, y2 = map(int, xyxy)
|
| 178 |
color = (0, 255, 0) if not use_llm else (255, 165, 0) # LLM用橙色框
|
|
@@ -180,7 +175,7 @@ def predict(
|
|
| 180 |
cv2.putText(vis, f"{label_show} {score:.2f}", (x1, max(12, y1-6)),
|
| 181 |
cv2.FONT_HERSHEY_SIMPLEX, 0.55, color, 2)
|
| 182 |
else:
|
| 183 |
-
text_to_show = f"{label_show} {score:.2f} (
|
| 184 |
if note:
|
| 185 |
text_to_show = note
|
| 186 |
cv2.putText(vis, text_to_show, (12, 24),
|
|
@@ -198,11 +193,11 @@ def predict(
|
|
| 198 |
return None, [{"error": str(e)}]
|
| 199 |
|
| 200 |
# ---------- Gradio UI ----------
|
| 201 |
-
with gr.Blocks(title="Marine Species ID – YOLO primary +
|
| 202 |
gr.Markdown(
|
| 203 |
-
"### Marine Species Identification (
|
| 204 |
"- **Primary**: FathomNet 2023 Baseline (YOLOv8m)\n"
|
| 205 |
-
f"- **Fallback**:
|
| 206 |
)
|
| 207 |
with gr.Row():
|
| 208 |
with gr.Column(scale=5):
|
|
@@ -216,8 +211,9 @@ with gr.Blocks(title="Marine Species ID – YOLO primary + VQA fallback") as dem
|
|
| 216 |
img_out = gr.Image(label="Detections", interactive=False)
|
| 217 |
json_out = gr.JSON(label="Detections JSON (label/conf/xyxy)")
|
| 218 |
btn.click(predict, inputs=[img_in, conf, iou, imgsz, pre], outputs=[img_out, json_out])
|
| 219 |
-
gr.Markdown(
|
| 220 |
-
|
| 221 |
|
| 222 |
if __name__ == "__main__":
|
| 223 |
demo.launch()
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import numpy as np
|
| 3 |
import cv2
|
| 4 |
+
import os, json, re, base64
|
| 5 |
from typing import List, Dict
|
| 6 |
from huggingface_hub import hf_hub_download, list_repo_files
|
| 7 |
from ultralytics import YOLO
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
# 【重要】请确保你的Hugging Face Space的requirements.txt文件里包含以下库:
|
| 12 |
+
# transformers
|
| 13 |
+
# torch
|
| 14 |
+
# sentencepiece
|
| 15 |
+
# Pillow
|
| 16 |
|
| 17 |
# ---------- Config ----------
|
| 18 |
FATHOM_REPO = os.getenv("FathomNet/fathomnet2023-comp-baseline", "FathomNet/fathomnet2023-comp-baseline")
|
| 19 |
FATHOM_PREF = ["fathomnet23-comp-baseline.pt", "best.pt", "yolov8m.pt"]
|
|
|
|
|
|
|
| 20 |
CONF_LOW = float(os.getenv("CONF_LOW", "0.8"))
|
| 21 |
|
| 22 |
+
# ---------- Load primary detector ----------
|
| 23 |
+
print("[init] loading FathomNet baseline (YOLO)...")
|
| 24 |
+
FATHOM_W = _resolve_weight(FATHOM_REPO, FATHOM_PREF)
|
| 25 |
+
FATHOM = YOLO(FATHOM_W)
|
| 26 |
+
|
| 27 |
+
# ----------【最终修改】加载一个本地的、可靠的备用模型,不再依赖不稳定的外部API ----------
|
| 28 |
+
print("[init] loading fallback vision model (BLIP)...")
|
| 29 |
+
try:
|
| 30 |
+
from transformers import BlipProcessor, BlipForConditionalGeneration
|
| 31 |
+
# 检查GPU是否可用
|
| 32 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 33 |
+
print(f"[init] Fallback model will run on: {DEVICE}")
|
| 34 |
+
|
| 35 |
+
# 加载模型和处理器
|
| 36 |
+
fallback_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
|
| 37 |
+
fallback_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(DEVICE)
|
| 38 |
+
FALLBACK_MODEL_LOADED = True
|
| 39 |
+
except ImportError:
|
| 40 |
+
print("[warn] transformers, torch, or pillow not installed. Fallback model will not be available.")
|
| 41 |
+
print("[warn] Please add 'transformers', 'torch', 'sentencepiece', 'Pillow' to your requirements.txt")
|
| 42 |
+
FALLBACK_MODEL_LOADED = False
|
| 43 |
+
DEVICE = "cpu"
|
| 44 |
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
# ---------- Utils ----------
|
| 47 |
def _resolve_weight(repo_id: str, prefer: List[str]) -> str:
|
| 48 |
+
# ... (此函数无需修改)
|
| 49 |
for fname in prefer:
|
| 50 |
try:
|
| 51 |
return hf_hub_download(repo_id=repo_id, filename=fname)
|
|
|
|
| 57 |
raise RuntimeError(f"No .pt weights found in repo: {repo_id}")
|
| 58 |
|
| 59 |
def _resize_limit_max_side(img_bgr: np.ndarray, max_side: int = 1280) -> np.ndarray:
|
| 60 |
+
# ... (此函数无需修改)
|
| 61 |
h, w = img_bgr.shape[:2]
|
| 62 |
m = max(h, w)
|
| 63 |
if m <= max_side:
|
|
|
|
| 66 |
return cv2.resize(img_bgr, (int(w*scale), int(h*scale)), interpolation=cv2.INTER_AREA)
|
| 67 |
|
| 68 |
def uw_preprocess_bgr(img_bgr: np.ndarray) -> np.ndarray:
|
| 69 |
+
# ... (此函数无需修改)
|
| 70 |
lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB)
|
| 71 |
l, a, b = cv2.split(lab); l = cv2.equalizeHist(l)
|
| 72 |
img_bgr = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2BGR)
|
|
|
|
| 79 |
return np.uint8(np.clip(np.dstack([bch, gch, rch]) * t, 0, 255))
|
| 80 |
|
| 81 |
def _parse_yolo_result(ultra_res, names: Dict[int, str]):
|
| 82 |
+
# ... (此函数无需修改)
|
| 83 |
dets = []
|
| 84 |
for b in ultra_res.boxes:
|
| 85 |
cls_id = int(b.cls.item())
|
|
|
|
| 90 |
return dets
|
| 91 |
|
| 92 |
def _best_box(dets):
|
| 93 |
+
# ... (此函数无需修改)
|
| 94 |
if not dets: return None
|
| 95 |
return max(dets, key=lambda d: d["conf"])
|
| 96 |
|
| 97 |
def _crop_xyxy(img: np.ndarray, box, pad: int = 4) -> np.ndarray:
|
| 98 |
+
# ... (此函数无需修改)
|
| 99 |
h, w = img.shape[:2]
|
| 100 |
x1, y1, x2, y2 = [int(round(v)) for v in box]
|
| 101 |
x1 = max(0, x1 - pad); y1 = max(0, y1 - pad)
|
| 102 |
x2 = min(w - 1, x2 + pad); y2 = min(h - 1, y2 + pad)
|
| 103 |
return img[y1:y2, x1:x2, :]
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
+
# 【最终修改】重写LLM备用逻辑,使其调用加载到本地的BLIP模型
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
def llm_fallback(img_bgr: np.ndarray) -> Dict:
|
| 108 |
+
"""当主模型信心不足时,调用本地BLIP模型生成描述。"""
|
| 109 |
+
if not FALLBACK_MODEL_LOADED:
|
| 110 |
+
return {"label": "unknown", "conf": 0.51, "xyxy": None, "note": "Fallback model not loaded"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
try:
|
| 112 |
+
# 1. 将图像从OpenCV格式(BGR)转换为PIL格式(RGB)
|
| 113 |
+
raw_image = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
|
| 114 |
+
# 2. 为模型准备输入
|
| 115 |
+
inputs = fallback_processor(raw_image, return_tensors="pt").to(DEVICE)
|
| 116 |
+
# 3. 生成描述
|
| 117 |
+
out = fallback_model.generate(**inputs, max_new_tokens=20)
|
| 118 |
+
# 4. 解码并清理描述文本
|
| 119 |
+
caption = fallback_processor.decode(out[0], skip_special_tokens=True)
|
| 120 |
+
# 这是一个简单的启发式清理,可能需要根据实际情况微调
|
| 121 |
+
caption = caption.replace("a photograph of", "").replace("a close up of", "").strip()
|
| 122 |
+
# BLIP模型不提供置信度分数,我们给一个固定的值以表明这是备用模型的结果
|
| 123 |
+
return {"label": caption, "conf": 0.60, "xyxy": None}
|
| 124 |
except Exception as e:
|
| 125 |
+
return {"label": "unknown", "conf": 0.51, "xyxy": None, "note": f"Fallback model error: {e}"}
|
|
|
|
| 126 |
|
| 127 |
|
| 128 |
# ---------- Inference ----------
|
|
|
|
| 149 |
use_llm = (len(dets) == 0) or (max([d["conf"] for d in dets]) < CONF_LOW if dets else True)
|
| 150 |
|
| 151 |
if use_llm:
|
| 152 |
+
print("[info] Low confidence or no detection, triggering LOCAL fallback model...")
|
| 153 |
roi_img = bgr
|
| 154 |
best = _best_box(dets)
|
| 155 |
if best is not None:
|
| 156 |
roi_img = _crop_xyxy(bgr, best["xyxy"])
|
|
|
|
| 157 |
llm_det = llm_fallback(roi_img)
|
|
|
|
| 158 |
if best is not None and llm_det.get("xyxy") is None:
|
| 159 |
llm_det["xyxy"] = best["xyxy"]
|
| 160 |
enhanced = [llm_det]
|
|
|
|
| 168 |
score = d["conf"]
|
| 169 |
xyxy = d.get("xyxy")
|
| 170 |
note = d.get("note")
|
|
|
|
| 171 |
if xyxy and isinstance(xyxy, list) and len(xyxy) == 4:
|
| 172 |
x1, y1, x2, y2 = map(int, xyxy)
|
| 173 |
color = (0, 255, 0) if not use_llm else (255, 165, 0) # LLM用橙色框
|
|
|
|
| 175 |
cv2.putText(vis, f"{label_show} {score:.2f}", (x1, max(12, y1-6)),
|
| 176 |
cv2.FONT_HERSHEY_SIMPLEX, 0.55, color, 2)
|
| 177 |
else:
|
| 178 |
+
text_to_show = f"{label_show} {score:.2f} (Fallback Model)"
|
| 179 |
if note:
|
| 180 |
text_to_show = note
|
| 181 |
cv2.putText(vis, text_to_show, (12, 24),
|
|
|
|
| 193 |
return None, [{"error": str(e)}]
|
| 194 |
|
| 195 |
# ---------- Gradio UI ----------
|
| 196 |
+
with gr.Blocks(title="Marine Species ID – YOLO primary + Local Fallback") as demo:
|
| 197 |
gr.Markdown(
|
| 198 |
+
"### Marine Species Identification (with Self-Contained Fallback)\n"
|
| 199 |
"- **Primary**: FathomNet 2023 Baseline (YOLOv8m)\n"
|
| 200 |
+
f"- **Fallback**: Local BLIP Model (triggered when max conf < {CONF_LOW} or no boxes)"
|
| 201 |
)
|
| 202 |
with gr.Row():
|
| 203 |
with gr.Column(scale=5):
|
|
|
|
| 211 |
img_out = gr.Image(label="Detections", interactive=False)
|
| 212 |
json_out = gr.JSON(label="Detections JSON (label/conf/xyxy)")
|
| 213 |
btn.click(predict, inputs=[img_in, conf, iou, imgsz, pre], outputs=[img_out, json_out])
|
| 214 |
+
gr.Markdown("Tip: This app is now self-contained and does not require an HF_TOKEN. "
|
| 215 |
+
"The first launch may be slow due to model downloads.")
|
| 216 |
|
| 217 |
if __name__ == "__main__":
|
| 218 |
demo.launch()
|
| 219 |
+
|