Spaces:
Sleeping
Sleeping
Junbrobro commited on
Commit ·
978477f
1
Parent(s): 28b7152
Fix vector DB loading: robust unzip + no re-embedding
Browse files- src/chatbot.py +20 -4
- src/retrieval.py +40 -32
src/chatbot.py
CHANGED
|
@@ -9,17 +9,33 @@ import sys
|
|
| 9 |
from pathlib import Path
|
| 10 |
from typing import List, Dict
|
| 11 |
|
|
|
|
| 12 |
PROJECT_ROOT = Path(__file__).parent.parent
|
| 13 |
sys.path.insert(0, str(PROJECT_ROOT / "src"))
|
| 14 |
|
|
|
|
|
|
|
|
|
|
| 15 |
from retrieval import FaissRetriever
|
| 16 |
-
from embedding import
|
| 17 |
from llmmodel import get_llm
|
| 18 |
from prompt import (
|
| 19 |
create_rag_prompt,
|
| 20 |
format_answer
|
| 21 |
)
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
class SocialCultureChatbot:
|
| 25 |
"""
|
|
@@ -38,7 +54,7 @@ class SocialCultureChatbot:
|
|
| 38 |
print("=" * 60)
|
| 39 |
|
| 40 |
# 임베딩 모델 (encode only)
|
| 41 |
-
print("\n๐ค ์๋ฒ ๋ฉ
|
| 42 |
self.embedding_model = EmbeddingModel()
|
| 43 |
|
| 44 |
# FAISS 검색기
|
|
@@ -60,7 +76,7 @@ class SocialCultureChatbot:
|
|
| 60 |
"""
|
| 61 |
질문에 답변합니다.
|
| 62 |
"""
|
| 63 |
-
# 1. 검색
|
| 64 |
results = self.retriever.retrieve(
|
| 65 |
query=question,
|
| 66 |
top_k=self.top_k
|
|
@@ -71,7 +87,7 @@ class SocialCultureChatbot:
|
|
| 71 |
[f"[{i+1}] {r['text']}" for i, r in enumerate(results)]
|
| 72 |
)
|
| 73 |
|
| 74 |
-
# 3. 프롬프트
|
| 75 |
prompt = create_rag_prompt(question, context)
|
| 76 |
|
| 77 |
# 4. LLM 호출
|
|
|
|
| 9 |
from pathlib import Path
|
| 10 |
from typing import List, Dict
|
| 11 |
|
| 12 |
+
# 프로젝트 루트 설정
|
| 13 |
PROJECT_ROOT = Path(__file__).parent.parent
|
| 14 |
sys.path.insert(0, str(PROJECT_ROOT / "src"))
|
| 15 |
|
| 16 |
+
# =========================
|
| 17 |
+
# imports
|
| 18 |
+
# =========================
|
| 19 |
from retrieval import FaissRetriever
|
| 20 |
+
from embedding import embed_query
|
| 21 |
from llmmodel import get_llm
|
| 22 |
from prompt import (
|
| 23 |
create_rag_prompt,
|
| 24 |
format_answer
|
| 25 |
)
|
| 26 |
|
| 27 |
+
# =========================
|
| 28 |
+
# 최소 래퍼: EmbeddingModel
|
| 29 |
+
# =========================
|
| 30 |
+
class EmbeddingModel:
|
| 31 |
+
"""
|
| 32 |
+
FAISS 검색용 쿼리 임베딩 래퍼
|
| 33 |
+
- vector DB๋ ์ฌ์ ๊ตฌ์ถ๋จ
|
| 34 |
+
- encode(query)만 제공
|
| 35 |
+
"""
|
| 36 |
+
def encode(self, text: str):
|
| 37 |
+
return embed_query(text)
|
| 38 |
+
|
| 39 |
|
| 40 |
class SocialCultureChatbot:
|
| 41 |
"""
|
|
|
|
| 54 |
print("=" * 60)
|
| 55 |
|
| 56 |
# 임베딩 모델 (encode only)
|
| 57 |
+
print("\n🤖 임베딩 래퍼 초기화 중...")
|
| 58 |
self.embedding_model = EmbeddingModel()
|
| 59 |
|
| 60 |
# FAISS 검색기
|
|
|
|
| 76 |
"""
|
| 77 |
질문에 답변합니다.
|
| 78 |
"""
|
| 79 |
+
# 1. FAISS 검색
|
| 80 |
results = self.retriever.retrieve(
|
| 81 |
query=question,
|
| 82 |
top_k=self.top_k
|
|
|
|
| 87 |
[f"[{i+1}] {r['text']}" for i, r in enumerate(results)]
|
| 88 |
)
|
| 89 |
|
| 90 |
+
# 3. 프롬프트 생성
|
| 91 |
prompt = create_rag_prompt(question, context)
|
| 92 |
|
| 93 |
# 4. LLM 호출
|
src/retrieval.py
CHANGED
|
@@ -5,24 +5,39 @@
|
|
| 5 |
- ❌ 문서 재임베딩 없음
|
| 6 |
"""
|
| 7 |
|
| 8 |
-
from typing import List, Dict
|
| 9 |
from pathlib import Path
|
|
|
|
| 10 |
import sys
|
| 11 |
|
| 12 |
PROJECT_ROOT = Path(__file__).parent.parent
|
| 13 |
sys.path.insert(0, str(PROJECT_ROOT / "src"))
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
|
| 18 |
|
| 19 |
-
class
|
| 20 |
"""
|
| 21 |
FAISS 기반 RAG 검색기
|
| 22 |
"""
|
| 23 |
|
| 24 |
-
def __init__(
|
| 25 |
-
self
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
def retrieve(
|
| 28 |
self,
|
|
@@ -30,34 +45,27 @@ class Retriever:
|
|
| 30 |
top_k: int = 5
|
| 31 |
) -> List[Dict]:
|
| 32 |
"""
|
| 33 |
-
FAISS 검색
|
| 34 |
"""
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
self,
|
| 39 |
-
query: str,
|
| 40 |
-
top_k: int = 5
|
| 41 |
-
) -> Tuple[List[Dict], str]:
|
| 42 |
-
"""
|
| 43 |
-
검색 결과 + context 문자열 반환
|
| 44 |
-
"""
|
| 45 |
-
results = self.retrieve(query, top_k=top_k)
|
| 46 |
-
|
| 47 |
-
context_parts = []
|
| 48 |
-
for i, r in enumerate(results):
|
| 49 |
-
page = r["metadata"].get("page_number", "N/A")
|
| 50 |
-
context_parts.append(
|
| 51 |
-
f"[문서 {i+1}] (페이지 {page})\n{r['text']}"
|
| 52 |
-
)
|
| 53 |
-
|
| 54 |
-
context = "\n\n---\n\n".join(context_parts)
|
| 55 |
-
return results, context
|
| 56 |
|
|
|
|
|
|
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
|
|
|
|
|
| 5 |
- ❌ 문서 재임베딩 없음
|
| 6 |
"""
|
| 7 |
|
| 8 |
+
from typing import List, Dict
|
| 9 |
from pathlib import Path
|
| 10 |
+
import numpy as np
|
| 11 |
import sys
|
| 12 |
|
| 13 |
PROJECT_ROOT = Path(__file__).parent.parent
|
| 14 |
sys.path.insert(0, str(PROJECT_ROOT / "src"))
|
| 15 |
|
| 16 |
+
import faiss
|
| 17 |
+
import json
|
| 18 |
|
| 19 |
|
| 20 |
+
class FaissRetriever:
|
| 21 |
"""
|
| 22 |
FAISS 기반 RAG 검색기
|
| 23 |
"""
|
| 24 |
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
vector_db_dir: str,
|
| 28 |
+
embedding_model
|
| 29 |
+
):
|
| 30 |
+
self.vector_db_dir = Path(vector_db_dir)
|
| 31 |
+
self.embedding_model = embedding_model
|
| 32 |
+
|
| 33 |
+
# FAISS index
|
| 34 |
+
index_path = self.vector_db_dir / "faiss_index.bin"
|
| 35 |
+
self.index = faiss.read_index(str(index_path))
|
| 36 |
+
|
| 37 |
+
# metadata
|
| 38 |
+
meta_path = self.vector_db_dir / "embeddings_metadata.json"
|
| 39 |
+
with open(meta_path, "r", encoding="utf-8") as f:
|
| 40 |
+
self.metadata = json.load(f)
|
| 41 |
|
| 42 |
def retrieve(
|
| 43 |
self,
|
|
|
|
| 45 |
top_k: int = 5
|
| 46 |
) -> List[Dict]:
|
| 47 |
"""
|
| 48 |
+
FAISS 검색 수행
|
| 49 |
"""
|
| 50 |
+
# 1. query embedding
|
| 51 |
+
query_vec = self.embedding_model.encode(query)
|
| 52 |
+
query_vec = np.array([query_vec]).astype("float32")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
+
# 2. FAISS search
|
| 55 |
+
distances, indices = self.index.search(query_vec, top_k)
|
| 56 |
|
| 57 |
+
# 3. 결과 구성
|
| 58 |
+
results = []
|
| 59 |
+
for rank, idx in enumerate(indices[0]):
|
| 60 |
+
if idx < 0:
|
| 61 |
+
continue
|
| 62 |
|
| 63 |
+
meta = self.metadata[idx]
|
| 64 |
+
results.append({
|
| 65 |
+
"rank": rank + 1,
|
| 66 |
+
"score": float(distances[0][rank]),
|
| 67 |
+
"text": meta["text"],
|
| 68 |
+
"metadata": meta.get("metadata", {})
|
| 69 |
+
})
|
| 70 |
|
| 71 |
+
return results
|