Junbrobro committed on
Commit
978477f
·
1 Parent(s): 28b7152

Fix vector DB loading: robust unzip + no re-embedding

Browse files
Files changed (2) hide show
  1. src/chatbot.py +20 -4
  2. src/retrieval.py +40 -32
src/chatbot.py CHANGED
@@ -9,17 +9,33 @@ import sys
9
  from pathlib import Path
10
  from typing import List, Dict
11
 
 
12
  PROJECT_ROOT = Path(__file__).parent.parent
13
  sys.path.insert(0, str(PROJECT_ROOT / "src"))
14
 
 
 
 
15
  from retrieval import FaissRetriever
16
- from embedding import EmbeddingModel
17
  from llmmodel import get_llm
18
  from prompt import (
19
  create_rag_prompt,
20
  format_answer
21
  )
22
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  class SocialCultureChatbot:
25
  """
@@ -38,7 +54,7 @@ class SocialCultureChatbot:
38
  print("=" * 60)
39
 
40
  # ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ (encode only)
41
- print("\n๐Ÿ”ค ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ๋กœ๋“œ ์ค‘...")
42
  self.embedding_model = EmbeddingModel()
43
 
44
  # FAISS ๊ฒ€์ƒ‰๊ธฐ
@@ -60,7 +76,7 @@ class SocialCultureChatbot:
60
  """
61
  ์งˆ๋ฌธ์— ๋‹ต๋ณ€ํ•ฉ๋‹ˆ๋‹ค.
62
  """
63
- # 1. ๊ฒ€์ƒ‰
64
  results = self.retriever.retrieve(
65
  query=question,
66
  top_k=self.top_k
@@ -71,7 +87,7 @@ class SocialCultureChatbot:
71
  [f"[{i+1}] {r['text']}" for i, r in enumerate(results)]
72
  )
73
 
74
- # 3. ํ”„๋กฌํ”„ํŠธ
75
  prompt = create_rag_prompt(question, context)
76
 
77
  # 4. LLM ํ˜ธ์ถœ
 
9
  from pathlib import Path
10
  from typing import List, Dict
11
 
12
+ # ํ”„๋กœ์ ํŠธ ๋ฃจํŠธ ์„ค์ •
13
  PROJECT_ROOT = Path(__file__).parent.parent
14
  sys.path.insert(0, str(PROJECT_ROOT / "src"))
15
 
16
+ # =========================
17
+ # imports
18
+ # =========================
19
  from retrieval import FaissRetriever
20
+ from embedding import embed_query
21
  from llmmodel import get_llm
22
  from prompt import (
23
  create_rag_prompt,
24
  format_answer
25
  )
26
 
27
+ # =========================
28
+ # ์ตœ์†Œ ๋ž˜ํผ: EmbeddingModel
29
+ # =========================
30
class EmbeddingModel:
    """Minimal query-embedding wrapper used by the FAISS retriever.

    The vector DB is built offline; this class exposes only
    ``encode(query)`` so incoming questions can be embedded at
    search time (no document re-embedding).
    """

    def encode(self, text: str):
        # Delegate directly to the project-level embedding helper.
        return embed_query(text)
38
+
39
 
40
  class SocialCultureChatbot:
41
  """
 
54
  print("=" * 60)
55
 
56
  # ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ (encode only)
57
+ print("\n๐Ÿ”ค ์ž„๋ฒ ๋”ฉ ๋ž˜ํผ ์ดˆ๊ธฐํ™” ์ค‘...")
58
  self.embedding_model = EmbeddingModel()
59
 
60
  # FAISS ๊ฒ€์ƒ‰๊ธฐ
 
76
  """
77
  ์งˆ๋ฌธ์— ๋‹ต๋ณ€ํ•ฉ๋‹ˆ๋‹ค.
78
  """
79
+ # 1. FAISS ๊ฒ€์ƒ‰
80
  results = self.retriever.retrieve(
81
  query=question,
82
  top_k=self.top_k
 
87
  [f"[{i+1}] {r['text']}" for i, r in enumerate(results)]
88
  )
89
 
90
+ # 3. ํ”„๋กฌํ”„ํŠธ ์ƒ์„ฑ
91
  prompt = create_rag_prompt(question, context)
92
 
93
  # 4. LLM ํ˜ธ์ถœ
src/retrieval.py CHANGED
@@ -5,24 +5,39 @@
5
  - โŒ ๋ฌธ์„œ ์žฌ์ž„๋ฒ ๋”ฉ ์—†์Œ
6
  """
7
 
8
- from typing import List, Dict, Tuple
9
  from pathlib import Path
 
10
  import sys
11
 
12
  PROJECT_ROOT = Path(__file__).parent.parent
13
  sys.path.insert(0, str(PROJECT_ROOT / "src"))
14
 
15
- from search_model_setup import get_search_model
16
- from embedding import embed_text
17
 
18
 
19
class Retriever:
    """FAISS-based RAG retriever (search only — no re-embedding)."""

    def __init__(self, collection_name: str = "social_culture"):
        # The search model wraps the pre-built FAISS collection.
        self.search_model = get_search_model(collection_name)

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict]:
        """Run a FAISS similarity search and return the raw hit dicts."""
        return self.search_model.search(query, top_k=top_k)

    def retrieve_with_context(
        self,
        query: str,
        top_k: int = 5,
    ) -> Tuple[List[Dict], str]:
        """Return the search hits together with a formatted context string."""
        hits = self.retrieve(query, top_k=top_k)

        # One labelled section per hit; page number comes from metadata
        # when present, "N/A" otherwise.
        sections = [
            f"[문서 {i+1}] (페이지 {hit['metadata'].get('page_number', 'N/A')})\n{hit['text']}"
            for i, hit in enumerate(hits)
        ]

        return hits, "\n\n---\n\n".join(sections)
56
 
 
 
57
 
58
def get_retriever(collection_name: str = "social_culture") -> Retriever:
    """Convenience factory returning a :class:`Retriever` for *collection_name*."""
    return Retriever(collection_name)


# Backward-compatibility alias: existing callers import FaissRetriever.
FaissRetriever = Retriever
 
 
 
 
 
63
 
 
 
5
  - โŒ ๋ฌธ์„œ ์žฌ์ž„๋ฒ ๋”ฉ ์—†์Œ
6
  """
7
 
8
+ from typing import List, Dict
9
  from pathlib import Path
10
+ import numpy as np
11
  import sys
12
 
13
  PROJECT_ROOT = Path(__file__).parent.parent
14
  sys.path.insert(0, str(PROJECT_ROOT / "src"))
15
 
16
+ import faiss
17
+ import json
18
 
19
 
20
class FaissRetriever:
    """FAISS-based RAG retriever.

    Loads a pre-built FAISS index and its aligned metadata from disk and
    performs similarity search only — documents are never re-embedded here.
    """

    def __init__(
        self,
        vector_db_dir: str,
        embedding_model
    ):
        """
        Args:
            vector_db_dir: directory containing ``faiss_index.bin`` and
                ``embeddings_metadata.json``.
            embedding_model: object exposing ``encode(text)`` returning a
                query vector compatible with the index dimensionality.
        """
        self.vector_db_dir = Path(vector_db_dir)
        self.embedding_model = embedding_model

        # Pre-built FAISS index — loaded as-is, never rebuilt.
        self.index = faiss.read_index(str(self.vector_db_dir / "faiss_index.bin"))

        # Per-vector metadata; list positions line up with index ids.
        meta_file = self.vector_db_dir / "embeddings_metadata.json"
        self.metadata = json.loads(meta_file.read_text(encoding="utf-8"))

    def retrieve(
        self,
        query: str,
        top_k: int = 5
    ) -> List[Dict]:
        """Embed *query* and return the ``top_k`` nearest chunks as dicts."""
        # Shape the query embedding as a (1, dim) float32 batch for FAISS.
        batch = np.array([self.embedding_model.encode(query)]).astype("float32")

        # Similarity search over the loaded index.
        scores, ids = self.index.search(batch, top_k)

        # Build result dicts; FAISS pads missing hits with id -1, skip them.
        hits = []
        for pos, doc_id in enumerate(ids[0]):
            if doc_id < 0:
                continue

            entry = self.metadata[doc_id]
            hits.append({
                "rank": pos + 1,
                "score": float(scores[0][pos]),
                "text": entry["text"],
                "metadata": entry.get("metadata", {}),
            })

        return hits