import os
import pickle
import re
import json
import asyncio
import tiktoken
from typing import List, Dict
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_chroma import Chroma
from sentence_transformers import CrossEncoder
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document

load_dotenv()

app = FastAPI(title="Hanyang RAG Chatbot")
# --- Settings & resource loading ---
DB_PATH = "./db/chroma"
BM25_PATH = "./db/bm25.pkl"
MODEL_NAME = "gpt-5-mini"

# Embeddings & vector DB load (a "large" embedding model also exists)
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
vector_store = Chroma(
    persist_directory=DB_PATH,
    embedding_function=embeddings,
    collection_name="hanyang_rules"
)
# Load the BM25 index
with open(BM25_PATH, "rb") as f:
    bm25_data = pickle.load(f)
bm25 = bm25_data["bm25"]
doc_store = bm25_data["documents"]
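# NOTE (assumption): bm25.pkl is expected to have been produced by the ingestion
# script as {"bm25": <rank_bm25-style index>, "documents": [Document, ...]},
# with the two aligned so that bm25.get_scores(...)[i] scores doc_store[i].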
# Tokenizer setup (fall back to cl100k_base if tiktoken does not know the model)
try:
    tokenizer = tiktoken.encoding_for_model(MODEL_NAME)
except KeyError:
    tokenizer = tiktoken.get_encoding("cl100k_base")

def tiktoken_tokenizer(text):
    tokens = tokenizer.encode(text)
    return [str(t) for t in tokens]
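# NOTE: BM25 only matches a query against the index when both sides use the
# same tokenization, so this assumes the index was built with this same
# tiktoken encoding (token ids stringified) at ingestion time.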
# Load the reranker (a local/cached model is used to reduce CPU load)
reranker = CrossEncoder("BAAI/bge-reranker-v2-m3")

# LLM setup
llm = ChatOpenAI(model=MODEL_NAME)
class ChatRequest(BaseModel):
    query: str
    history: List[Dict[str, str]] = []
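# NOTE: `history` is accepted by the schema but is not passed into
# generate_chat_stream below, so multi-turn context is currently ignored.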
async def generate_chat_stream(query: str):
    """
    Yields step-by-step progress ("log") and the final answer ("answer")
    in real time. Format: one JSON string per line (NDJSON).
    """
    # 1. Retrieval step
    yield json.dumps({"type": "log", "content": "🔍 [1/4] Running hybrid search (vector + BM25)..."}) + "\n"
    await asyncio.sleep(0.1)
    # --- Retrieval logic ---
    # 1. Dense (vector) search
    dense_results = vector_store.similarity_search_with_score(query, k=10)

    # 2. Sparse (BM25) search
    tokenized_query = tiktoken_tokenizer(query)
    bm25_scores = bm25.get_scores(tokenized_query)
    top_n_bm25_indices = sorted(range(len(bm25_scores)), key=lambda i: bm25_scores[i], reverse=True)[:10]
    bm25_results = [doc_store[i] for i in top_n_bm25_indices]

    # 3. Rule-based search: match "제 N 조" ("Article N") in the Korean query
    rule_results = []
    match = re.search(r"제\s*(\d+)\s*조", query)
    if match:
        target_article = match.group(1)
        for doc in doc_store:
            if doc.metadata.get("article_id") == target_article:
                rule_results.append(doc)

    yield json.dumps({"type": "log", "content": f"📄 Retrieval done: Dense({len(dense_results)}) / BM25({len(bm25_results)}) / Rule({len(rule_results)})"}) + "\n"
    # --- RRF Fusion ---
    yield json.dumps({"type": "log", "content": "🔄 [2/4] Merging results with RRF..."}) + "\n"
    rrf_k = 60
    doc_scores = {}
    doc_obj_map = {}
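    # Reciprocal Rank Fusion: each result list contributes
    # weight * 1 / (rrf_k + rank + 1) to a document's fused score, so documents
    # ranked highly by several retrievers rise to the top. rrf_k = 60 is the
    # conventional constant that keeps a single #1 hit from dominating.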
    def update_rrf(docs, weight=1.0):
        for rank, doc in enumerate(docs):
            unique_key = doc.page_content + doc.metadata.get("source", "")
            if unique_key not in doc_scores:
                doc_scores[unique_key] = 0.0
                doc_obj_map[unique_key] = doc
            doc_scores[unique_key] += weight * (1 / (rrf_k + rank + 1))

    update_rrf([d[0] for d in dense_results])
    update_rrf(bm25_results)
    update_rrf(rule_results, weight=3.0)  # boost exact article matches

    sorted_docs = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)
    # Keep the top 10 fused documents as reranking candidates
    candidates = [doc_obj_map[key] for key, score in sorted_docs[:10]]
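    # A cross-encoder scores each (query, passage) pair jointly in one forward
    # pass: slower than the bi-encoder used for dense retrieval, but more
    # accurate, which is why it is applied only to these 10 candidates.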
    # --- Reranking ---
    yield json.dumps({"type": "log", "content": f"⚖️ [3/4] Cross-encoder reranking ({len(candidates)} candidates)..."}) + "\n"
    final_docs = []
    if candidates:
        pairs = []
        for doc in candidates:
            source_prefix = f"[{doc.metadata.get('source', 'document')}] "
            pairs.append([query, source_prefix + doc.page_content])

        # CPU-bound bottleneck
        scores = reranker.predict(pairs)
        scored_docs = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
        final_docs = [doc for doc, score in scored_docs[:5]]  # keep the final 5

        # Send detailed logs (reveal the top-3 scores)
        for i, (doc, score) in enumerate(scored_docs[:3]):
            log_msg = f"   - Rank {i+1}: {doc.metadata.get('source')} (Score: {score:.4f})"
            yield json.dumps({"type": "log", "content": log_msg}) + "\n"
    # --- LLM generation ---
    yield json.dumps({"type": "log", "content": "🤖 [4/4] Generating the answer with GPT-5-mini..."}) + "\n"

    context_text = ""
    for doc in final_docs:
        source = doc.metadata.get("source", "Unknown")
        context_text += f"📄 Document: {source}\nContent: {doc.page_content}\n\n"

    system_prompt = """You are a chatbot that answers questions about Hanyang University's academic rules and regulations.
When you answer a question, always cite the supporting document and provision (e.g., Academic Regulations, Article 5).
Do not answer with anything that is not in the Context."""

    prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        ("user", "Context:\n{context}\n\nQuestion: {question}")
    ])

    # Stream the LLM output: forward each token to the client as it is generated
    chain = prompt | llm | StrOutputParser()
    async for token in chain.astream({"context": context_text, "question": query}):
        yield json.dumps({"type": "answer", "content": token}) + "\n"

    yield json.dumps({"type": "log", "content": "✅ Answer generation complete."}) + "\n"

    # Send the reference document info
    doc_info = [{"source": d.metadata.get("source"), "content": d.page_content[:100] + "..."} for d in final_docs]
    yield json.dumps({"type": "docs", "content": doc_info}) + "\n"
@app.post("/chat")
async def chat_endpoint(req: ChatRequest):
    return StreamingResponse(generate_chat_stream(req.query), media_type="application/x-ndjson")
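# Example client call, assuming the server runs locally on port 8000 and the
# /chat route above (-N disables curl buffering so NDJSON lines stream live):
#
#   curl -N -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"query": "제5조"}'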
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)