Update main.py

main.py CHANGED
@@ -7,19 +7,29 @@ import numpy as np
 import requests
 from fastapi import FastAPI, BackgroundTasks, Header, HTTPException, Query
 from pydantic import BaseModel, Field
-from qdrant_client import QdrantClient
-from qdrant_client.http.models import VectorParams, Distance, PointStruct
+
+# Qdrant (optional if VECTOR_STORE=memory)
+try:
+    from qdrant_client import QdrantClient
+    from qdrant_client.http.models import VectorParams, Distance, PointStruct
+except Exception:  # if it is not installed, fall back to the in-memory store
+    QdrantClient = None
+    VectorParams = Distance = PointStruct = None
 
 # ---------- logging ----------
 logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s")
 LOG = logging.getLogger("remote_indexer")
 
 # ---------- ENV (config) ----------
-#
+# Store choice: "qdrant" (default) or "memory"
+VECTOR_STORE = os.getenv("VECTOR_STORE", "qdrant").strip().lower()
+
+# Order of embedding backends to try. Default: DeepInfra, then HF.
 DEFAULT_BACKENDS = "deepinfra,hf"
-EMB_BACKEND_ORDER = [s.strip().lower()
+EMB_BACKEND_ORDER = [s.strip().lower()
+                     for s in os.getenv("EMB_BACKEND_ORDER", os.getenv("EMB_BACKEND", DEFAULT_BACKENDS)).split(",")
+                     if s.strip()]
 
-# Auto-fallback to DeepInfra if HF returns the infamous Similarity error
 ALLOW_DI_AUTOFALLBACK = os.getenv("ALLOW_DI_AUTOFALLBACK", "true").lower() in ("1","true","yes","on")
 
 # HF Inference API
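A quick sketch of how the new multi-line `EMB_BACKEND_ORDER` parsing behaves; the env values below are illustrative, not from the commit:

```python
import os

# Illustrative input: mixed case, stray spaces, trailing comma.
os.environ["EMB_BACKEND_ORDER"] = " HF , DeepInfra ,"

DEFAULT_BACKENDS = "deepinfra,hf"
order = [s.strip().lower()
         for s in os.getenv("EMB_BACKEND_ORDER", os.getenv("EMB_BACKEND", DEFAULT_BACKENDS)).split(",")
         if s.strip()]
print(order)  # ['hf', 'deepinfra']: trimmed, lowercased, empty entries dropped
```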
@@ -42,39 +52,133 @@ HF_TIMEOUT = float(os.getenv("EMB_TIMEOUT_SEC", "120"))
 HF_WAIT = os.getenv("HF_WAIT_FOR_MODEL", "true").lower() in ("1","true","yes","on")
 HF_PIPELINE_FIRST = os.getenv("HF_PIPELINE_FIRST", "true").lower() in ("1","true","yes","on")
 
-# DeepInfra
+# DeepInfra (OpenAI-compatible embeddings)
 DI_TOKEN = os.getenv("DEEPINFRA_API_KEY", "").strip()
-# 👇 IMPORTANT: a model that actually exists on DeepInfra (multilingual)
 DI_MODEL = os.getenv("DEEPINFRA_EMBED_MODEL", "BAAI/bge-m3").strip()
-# 👇 IMPORTANT: OpenAI-compatible endpoint
 DI_URL = os.getenv("DEEPINFRA_EMBED_URL", "https://api.deepinfra.com/v1/openai/embeddings").strip()
 DI_TIMEOUT = float(os.getenv("EMB_TIMEOUT_SEC", "120"))
 
-# Retries
+# Embedding retries
 RETRY_MAX = int(os.getenv("EMB_RETRY_MAX", "6"))
 RETRY_BASE_SEC = float(os.getenv("EMB_RETRY_BASE", "1.5"))
 RETRY_JITTER = float(os.getenv("EMB_RETRY_JITTER", "0.35"))
 
 # Qdrant
-QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
+QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333").strip()
 QDRANT_API = os.getenv("QDRANT_API_KEY", "").strip()
 
-# Auth
+# Service API auth (simple header)
 AUTH_TOKEN = os.getenv("REMOTE_INDEX_TOKEN", "").strip()
 
 LOG.info(f"Embeddings backend order = {EMB_BACKEND_ORDER}")
 LOG.info(f"HF pipeline URL = {HF_URL_PIPELINE}")
 LOG.info(f"HF models URL = {HF_URL_MODELS}")
+LOG.info(f"VECTOR_STORE = {VECTOR_STORE}")
 if "hf" in EMB_BACKEND_ORDER and not HF_TOKEN:
     LOG.warning("HF_API_TOKEN manquant — tentatives HF échoueront.")
 if "deepinfra" in EMB_BACKEND_ORDER and not DI_TOKEN:
     LOG.warning("DEEPINFRA_API_KEY manquant — tentatives DeepInfra échoueront.")
 
-# ----------
-
-
-
-
+# ---------- Vector store abstraction ----------
+class VectorStoreBase:
+    def ensure_collection(self, name: str, dim: int): ...
+    def upsert(self, name: str, vectors: np.ndarray, payloads: List[dict]) -> int: ...
+    def search(self, name: str, query_vec: np.ndarray, limit: int):
+        """return list of objects with .score and .payload"""
+        ...
+    def wipe(self, name: str): ...
+
+class MemoryHit:
+    def __init__(self, score: float, payload: dict):
+        self.score = score
+        self.payload = payload
+
+class MemoryStore(VectorStoreBase):
+    """Simple in-memory store (cosine over normalized vectors). Persistence: lifetime of the process."""
+    def __init__(self):
+        self.data: Dict[str, Dict[str, Any]] = {}  # {col: {"dim": d, "vecs": np.ndarray [N,d], "payloads": List[dict]}}
+        LOG.warning("Vector store: MEMORY (fallback). Les données sont volatiles (perdues au restart).")
+
+    def ensure_collection(self, name: str, dim: int):
+        col = self.data.get(name)
+        if not col:
+            self.data[name] = {"dim": dim, "vecs": np.zeros((0, dim), dtype=np.float32), "payloads": []}
+
+    def upsert(self, name: str, vectors: np.ndarray, payloads: List[dict]) -> int:
+        self.ensure_collection(name, vectors.shape[1])
+        col = self.data[name]
+        if vectors.ndim != 2 or vectors.shape[1] != col["dim"]:
+            raise RuntimeError(f"MemoryStore: bad shape {vectors.shape}, expected (*,{col['dim']})")
+        col["vecs"] = np.vstack([col["vecs"], vectors.astype(np.float32)])
+        col["payloads"].extend(payloads)
+        return vectors.shape[0]
+
+    def search(self, name: str, query_vec: np.ndarray, limit: int):
+        col = self.data.get(name)
+        if not col or col["vecs"].shape[0] == 0:
+            return []
+        V = col["vecs"]               # [N,d], already normalized
+        q = query_vec.reshape(1, -1)  # [1,d]
+        scores = (V @ q.T).ravel()    # cosine similarity
+        idx = np.argsort(-scores)[:limit]
+        return [MemoryHit(float(scores[i]), col["payloads"][i]) for i in idx]
+
+    def wipe(self, name: str):
+        if name in self.data:
+            del self.data[name]
+
+class QdrantStore(VectorStoreBase):
+    def __init__(self, url: str, api_key: Optional[str]):
+        if QdrantClient is None:
+            raise RuntimeError("qdrant-client non installé.")
+        self.client = QdrantClient(url=url, api_key=api_key if api_key else None)
+        # quick ping
+        try:
+            _ = self.client.get_collections()
+            LOG.info("Connecté à Qdrant.")
+        except Exception as e:
+            raise RuntimeError(f"Connexion Qdrant impossible: {e}")
+
+    def ensure_collection(self, name: str, dim: int):
+        try:
+            self.client.get_collection(name); return
+        except Exception:
+            pass
+        self.client.create_collection(
+            collection_name=name,
+            vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
+        )
+
+    def upsert(self, name: str, vectors: np.ndarray, payloads: List[dict]) -> int:
+        points = [
+            PointStruct(id=int.from_bytes(os.urandom(8), "big"), vector=v.tolist(), payload=payloads[i])  # random 64-bit id; Qdrant rejects id=None
+            for i, v in enumerate(vectors)
+        ]
+        self.client.upsert(collection_name=name, points=points)
+        return len(points)
+
+    def search(self, name: str, query_vec: np.ndarray, limit: int):
+        res = self.client.search(collection_name=name, query_vector=query_vec.tolist(), limit=limit)
+        return res
+
+    def wipe(self, name: str):
+        self.client.delete_collection(name)
+
+# Store selection / auto-fallback
+STORE: VectorStoreBase
+def _init_store() -> VectorStoreBase:
+    prefer = VECTOR_STORE
+    if prefer == "memory":
+        return MemoryStore()
+
+    # prefer qdrant
+    try:
+        return QdrantStore(QDRANT_URL, QDRANT_API if QDRANT_API else None)
+    except Exception as e:
+        LOG.error(f"Qdrant indisponible ({e}) — fallback en mémoire.")
+        return MemoryStore()
+
+STORE = _init_store()
 
 # ---------- Pydantic ----------
 class FileIn(BaseModel):
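To see why `MemoryStore.search` can rank by a plain matrix product, here is a minimal standalone check; it assumes the stored vectors are L2-normalized, which the embedding helpers in this file enforce:

```python
import numpy as np

# Three unit-norm vectors (rows) and a unit-norm query, as MemoryStore assumes.
V = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [0.7071, 0.7071]], dtype=np.float32)
q = np.array([1.0, 0.0], dtype=np.float32).reshape(1, -1)

scores = (V @ q.T).ravel()     # dot product == cosine similarity on unit vectors
idx = np.argsort(-scores)[:2]  # top-2, highest similarity first
print(idx, scores[idx])        # [0 2] [1.0 0.7071]
```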
@@ -142,13 +246,11 @@ def _hf_http(url: str, payload: Dict[str, Any], headers_extra: Optional[Dict[str
 
     data = r.json()
     arr = np.array(data, dtype=np.float32)
-    if arr.ndim == 3:
+    if arr.ndim == 3:
         arr = arr.mean(axis=1)
-    elif arr.ndim == 2:
-        pass
-    elif arr.ndim == 1:  # [dim] -> [1, dim]
+    elif arr.ndim == 1:
         arr = arr.reshape(1, -1)
-    else:
+    if arr.ndim != 2:
         raise RuntimeError(f"HF: unexpected embeddings shape: {arr.shape}")
 
     norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
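The reshuffled branch above funnels every HF response into a `[batch, dim]` matrix and rejects anything else; a small sketch of the two conversions it performs (shapes are illustrative):

```python
import numpy as np

token_level = np.ones((2, 3, 4), dtype=np.float32)  # [batch, tokens, dim]
pooled = token_level.mean(axis=1)                   # mean pooling -> [2, 4]

single = np.ones(4, dtype=np.float32)               # [dim]
single = single.reshape(1, -1)                      # -> [1, 4]

print(pooled.shape, single.shape)  # (2, 4) (1, 4); anything still not 2-D raises
```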
@@ -159,7 +261,6 @@ def _hf_post_embeddings_once(batch: List[str]) -> Tuple[np.ndarray, int]:
     payload: Dict[str, Any] = {"inputs": (batch if len(batch) > 1 else batch[0])}
     urls = [HF_URL_PIPELINE, HF_URL_MODELS] if HF_PIPELINE_FIRST else [HF_URL_MODELS, HF_URL_PIPELINE]
     last_exc: Optional[Exception] = None
-
     for idx, url in enumerate(urls, 1):
         try:
             if "/models/" in url:
@@ -190,10 +291,8 @@ def _hf_post_embeddings_once(batch: List[str]) -> Tuple[np.ndarray, int]:
 def _di_post_embeddings_once(batch: List[str]) -> Tuple[np.ndarray, int]:
     if not DI_TOKEN:
         raise RuntimeError("DEEPINFRA_API_KEY manquant (backend=deepinfra).")
-    # OpenAI-compatible embeddings endpoint
     headers = {"Authorization": f"Bearer {DI_TOKEN}", "Content-Type": "application/json", "Accept": "application/json"}
     payload = {"model": DI_MODEL, "input": batch}
-    # NB: "encoding_format":"float" can also be added if needed
     r = requests.post(DI_URL, headers=headers, json=payload, timeout=DI_TIMEOUT)
     size = int(r.headers.get("Content-Length", "0"))
     if r.status_code >= 400:
@@ -212,6 +311,11 @@ def _di_post_embeddings_once(batch: List[str]) -> Tuple[np.ndarray, int]:
     return arr.astype(np.float32), size
 
 # ---------- Retry orchestrator ----------
+def _retry_sleep(attempt: int):
+    back = (RETRY_BASE_SEC ** attempt)
+    jitter = 1.0 + random.uniform(-RETRY_JITTER, RETRY_JITTER)
+    return max(0.25, back * jitter)
+
 def _call_with_retries(func, batch: List[str], label: str, job_id: Optional[str] = None) -> Tuple[np.ndarray, int]:
     last_exc = None
     for attempt in range(RETRY_MAX):
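With the default env values (`EMB_RETRY_BASE=1.5`, `EMB_RETRY_JITTER=0.35`), `_retry_sleep` grows exponentially from roughly one second; a sketch that reproduces the formula:

```python
import random

RETRY_BASE_SEC, RETRY_JITTER = 1.5, 0.35  # defaults from the env block above

def retry_sleep(attempt: int) -> float:
    back = RETRY_BASE_SEC ** attempt                        # 1.0, 1.5, 2.25, 3.375, ...
    jitter = 1.0 + random.uniform(-RETRY_JITTER, RETRY_JITTER)
    return max(0.25, back * jitter)                         # never below 0.25 s

for attempt in range(4):
    print(attempt, round(retry_sleep(attempt), 2))  # e.g. 0 1.12 / 1 1.33 / 2 2.81 / 3 3.05
```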
@@ -238,13 +342,8 @@ def _call_with_retries(func, batch: List[str], label: str, job_id: Optional[str]
     raise RuntimeError(f"{label}: retries exhausted: {last_exc}")
 
 def _post_embeddings(batch: List[str], job_id: Optional[str] = None) -> Tuple[np.ndarray, int]:
-    """
-    Try the backends in EMB_BACKEND_ORDER, with retries.
-    Optional auto-fallback to DeepInfra if HF returns the Similarity error.
-    """
     last_err = None
     similarity_misroute = False
-
     for b in EMB_BACKEND_ORDER:
         if b == "hf":
             try:
@@ -265,25 +364,13 @@ def _post_embeddings(batch: List[str], job_id: Optional[str] = None) -> Tuple[np
             LOG.error(f"DeepInfra failed: {e}")
         else:
             _append_log(job_id, f"Backend inconnu ignoré: {b}")
-
     if ALLOW_DI_AUTOFALLBACK and similarity_misroute and DI_TOKEN:
         LOG.warning("HF a routé sur SentenceSimilarity => auto-fallback DeepInfra (override ordre).")
         _append_log(job_id, "Auto-fallback DeepInfra (HF => SentenceSimilarity).")
         return _call_with_retries(_di_post_embeddings_once, batch, "DeepInfra", job_id)
-
     raise RuntimeError(f"Tous les backends ont échoué: {last_err}")
 
-# ----------
-def _ensure_collection(name: str, dim: int):
-    try:
-        qdr.get_collection(name); return
-    except Exception:
-        pass
-    qdr.create_collection(
-        collection_name=name,
-        vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
-    )
-
+# ---------- Chunking ----------
 def _chunk_with_spans(text: str, size: int, overlap: int):
     n = len(text or "")
     if size <= 0:
@@ -300,56 +387,49 @@ def run_index_job(job_id: str, req: IndexRequest):
     try:
         _set_status(job_id, "running")
         total_chunks = 0
-        _append_log(job_id, f"Start project={req.project_id} files={len(req.files)} | backends={EMB_BACKEND_ORDER}")
+        _append_log(job_id, f"Start project={req.project_id} files={len(req.files)} | backends={EMB_BACKEND_ORDER} | store={VECTOR_STORE}")
         LOG.info(f"[{job_id}] Index start project={req.project_id} files={len(req.files)}")
 
+        # Warmup -> embedding dimension
         warm = "warmup"
         if req.files:
             for _, _, chunk_txt in _chunk_with_spans(req.files[0].text or "", req.chunk_size, req.overlap):
                 if (chunk_txt or "").strip():
                     warm = chunk_txt; break
-        embs,
+        embs, _ = _post_embeddings([warm], job_id=job_id)
         dim = embs.shape[1]
         col = f"proj_{req.project_id}"
-        _ensure_collection(col, dim)
+        STORE.ensure_collection(col, dim)
         _append_log(job_id, f"Collection ready: {col} (dim={dim})")
 
-        point_id = 0
+        # file loop
         for fi, f in enumerate(req.files, 1):
             if not (f.text or "").strip():
                 _append_log(job_id, f"file {fi}: vide — ignoré")
                 continue
-            chunks, metas = [], []
+
+            batch_txts, metas = [], []
+            def _flush():
+                nonlocal batch_txts, metas, total_chunks
+                if not batch_txts: return
+                vecs, sz = _post_embeddings(batch_txts, job_id=job_id)
+                added = STORE.upsert(col, vecs, metas)
+                total_chunks += added
+                _append_log(job_id, f"file {fi}/{len(req.files)}: +{added} chunks (total={total_chunks})")
+                batch_txts, metas = [], []
+
             for ci, (start, end, chunk_txt) in enumerate(_chunk_with_spans(f.text, req.chunk_size, req.overlap)):
-                if not (chunk_txt or "").strip():
+                if not (chunk_txt or "").strip():
                     continue
-                chunks.append(chunk_txt)
+                batch_txts.append(chunk_txt)
                 meta = {"path": f.path, "chunk": ci, "start": start, "end": end}
                 if req.store_text:
                     meta["text"] = chunk_txt
                 metas.append(meta)
-                if len(chunks) >= req.batch_size:
-                    vecs, sz = _post_embeddings(chunks, job_id=job_id)
-                    batch_points = [
-                        PointStruct(id=point_id + k, vector=vec.tolist(), payload=metas[k])
-                        for k, vec in enumerate(vecs)
-                    ]
-                    qdr.upsert(collection_name=col, points=batch_points)
-                    point_id += len(batch_points)
-                    total_chunks += len(chunks)
-                    _append_log(job_id, f"file {fi}/{len(req.files)}: +{len(chunks)} chunks (total={total_chunks}) ~{sz/1024:.1f}KiB")
-                    chunks, metas = [], []
-
-            if chunks:
-                vecs, sz = _post_embeddings(chunks, job_id=job_id)
-                batch_points = [
-                    PointStruct(id=point_id + k, vector=vec.tolist(), payload=metas[k])
-                    for k, vec in enumerate(vecs)
-                ]
-                qdr.upsert(collection_name=col, points=batch_points)
-                point_id += len(batch_points)
-                total_chunks += len(chunks)
-                _append_log(job_id, f"file {fi}/{len(req.files)}: +{len(chunks)} chunks (total={total_chunks}) ~{sz/1024:.1f}KiB")
+                if len(batch_txts) >= req.batch_size:
+                    _flush()
+
+            _flush()
 
         _append_log(job_id, f"Done. chunks={total_chunks}")
         _set_status(job_id, "done")
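`_chunk_with_spans` itself is untouched by this commit and its body sits mostly outside the hunk; to make the new `_flush` batching readable on its own, here is a hypothetical stand-in honoring the same `(start, end, chunk)` contract the loop unpacks:

```python
def chunk_with_spans(text: str, size: int, overlap: int):
    # Hypothetical stand-in: fixed-size windows with character overlap.
    n = len(text or "")
    if size <= 0:
        return
    step = max(1, size - max(0, overlap))
    for start in range(0, n, step):
        end = min(n, start + size)
        yield start, end, text[start:end]
        if end >= n:
            break

print(list(chunk_with_spans("abcdefghij", size=4, overlap=2)))
# [(0, 4, 'abcd'), (2, 6, 'cdef'), (4, 8, 'efgh'), (6, 10, 'ghij')]
```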
@@ -372,6 +452,8 @@ def root():
         "hf_url_models": HF_URL_MODELS if "hf" in EMB_BACKEND_ORDER else None,
         "di_url": DI_URL if "deepinfra" in EMB_BACKEND_ORDER else None,
         "di_model": DI_MODEL if "deepinfra" in EMB_BACKEND_ORDER else None,
+        "vector_store": VECTOR_STORE,
+        "vector_store_active": type(STORE).__name__,
         "docs": "/health, /index, /status/{job_id}, /query, /wipe"
     }
 
@@ -409,7 +491,7 @@ def status(job_id: str, x_auth_token: Optional[str] = Header(default=None)):
         raise HTTPException(404, "job inconnu")
     return {"status": j["status"], "logs": j["logs"][-800:]}
 
-#
+# Legacy compat
 @app.get("/status")
 def status_qp(job_id: str = Query(None), x_auth_token: Optional[str] = Header(default=None)):
     if AUTH_TOKEN and (x_auth_token or "") != AUTH_TOKEN:
@@ -439,15 +521,21 @@ def query(req: QueryRequest, x_auth_token: Optional[str] = Header(default=None))
         raise HTTPException(401, "Unauthorized")
     _check_backend_ready()
     k = int(max(1, min(50, req.top_k or 6)))
+
     vecs, _ = _post_embeddings([req.query])
     col = f"proj_{req.project_id}"
+
+    # Search against the active store
     try:
-        res = qdr.search(collection_name=col, query_vector=vecs[0].tolist(), limit=k)
+        hits = STORE.search(col, vecs[0], k)
     except Exception as e:
         raise HTTPException(400, f"Search failed: {e}")
+
     out = []
-    for p in res:
-        pl = p.payload or {}
+    # Qdrant returns objects with .score and .payload
+    for p in hits:
+        pl = getattr(p, "payload", None) or {}
+        score = float(getattr(p, "score", 0.0))
         txt = pl.get("text")
         if txt and len(txt) > 800:
             txt = txt[:800] + "..."
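For completeness, a hedged client-side sketch of the reworked `/query` endpoint; the URL and token are placeholders, and the JSON fields follow the `QueryRequest` usage visible in this hunk:

```python
import requests

resp = requests.post(
    "http://localhost:7860/query",          # placeholder base URL
    headers={"x-auth-token": "change-me"},  # matches the x_auth_token header check
    json={"project_id": "demo", "query": "retry backoff", "top_k": 6},
    timeout=30,
)
resp.raise_for_status()
for hit in resp.json()["results"]:
    print(round(hit["score"], 3), hit.get("path"), (hit.get("text") or "")[:80])
```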
@@ -457,7 +545,7 @@ def query(req: QueryRequest, x_auth_token: Optional[str] = Header(default=None))
             "start": pl.get("start"),
             "end": pl.get("end"),
             "text": txt,
-            "score": p.score,
+            "score": score,
         })
     return {"results": out}
 
@@ -467,7 +555,7 @@ def wipe_collection(project_id: str, x_auth_token: Optional[str] = Header(defaul
         raise HTTPException(401, "Unauthorized")
     col = f"proj_{project_id}"
     try:
-        qdr.delete_collection(col); return {"ok": True}
+        STORE.wipe(col); return {"ok": True}
     except Exception as e:
         raise HTTPException(400, f"wipe failed: {e}")
 