import os
import pickle
import re
import json
import asyncio

import tiktoken
from typing import List, Dict
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_chroma import Chroma
from sentence_transformers import CrossEncoder
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document

load_dotenv()
app = FastAPI(title="Hanyang RAG Chatbot")

# --- Configuration & loading ---
DB_PATH = "./db/chroma"
BM25_PATH = "./db/bm25.pkl"
MODEL_NAME = "gpt-5-mini"

# Embeddings & vector DB (a "large" embedding model is also available)
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
vector_store = Chroma(
    persist_directory=DB_PATH,
    embedding_function=embeddings,
    collection_name="hanyang_rules",
)

# Load the BM25 index and its document store
with open(BM25_PATH, "rb") as f:
    bm25_data = pickle.load(f)
bm25 = bm25_data["bm25"]
doc_store = bm25_data["documents"]

# Tokenizer setup (falls back to cl100k_base if the model is unknown to tiktoken)
try:
    tokenizer = tiktoken.encoding_for_model(MODEL_NAME)
except KeyError:
    tokenizer = tiktoken.get_encoding("cl100k_base")


def tiktoken_tokenizer(text):
    # BM25 expects the same tokenization that was used when the index was built.
    tokens = tokenizer.encode(text)
    return [str(t) for t in tokens]


# Load the reranker (a local/cached model is used to reduce CPU load)
reranker = CrossEncoder("BAAI/bge-reranker-v2-m3")

# LLM setup
llm = ChatOpenAI(model=MODEL_NAME)


class ChatRequest(BaseModel):
    query: str
    history: List[Dict[str, str]] = []


async def generate_chat_stream(query: str):
    """Yield step-by-step progress ("log") events and the final answer ("answer")
    in real time. Format: one JSON object per line (NDJSON)."""
    # 1. Retrieval stage
    yield json.dumps({"type": "log", "content": "🔍 [1/4] Hybrid Search(벡터+BM25) 수행 중..."}) + "\n"
    await asyncio.sleep(0.1)  # let the first log line reach the client before the blocking search calls

    # --- Retrieval logic ---
    # 1. Dense (vector) search
    dense_results = vector_store.similarity_search_with_score(query, k=10)

    # 2. BM25 (sparse keyword) search
    tokenized_query = tiktoken_tokenizer(query)
    bm25_scores = bm25.get_scores(tokenized_query)
    top_n_bm25_indices = sorted(range(len(bm25_scores)), key=lambda i: bm25_scores[i], reverse=True)[:10]
    bm25_results = [doc_store[i] for i in top_n_bm25_indices]

    # 3. Rule-based: direct article-number lookup
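    # If the query cites a specific article number (e.g. "제23조" or "제 23 조"),
    # the regex below captures the digits and collects every chunk whose
    # article_id metadata matches; these rule hits are later given extra weight
    # in the RRF fusion.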
    rule_results = []
    match = re.search(r"제\s*(\d+)\s*조", query)
    if match:
        target_article = match.group(1)
        for doc in doc_store:
            if doc.metadata.get("article_id") == target_article:
                rule_results.append(doc)

    yield json.dumps({"type": "log", "content": f"📊 검색 완료: Dense({len(dense_results)}) / BM25({len(bm25_results)}) / Rule({len(rule_results)})"}) + "\n"

    # --- RRF Fusion ---
    yield json.dumps({"type": "log", "content": "🔄 [2/4] RRF 알고리즘으로 결과 통합 중..."}) + "\n"
    rrf_k = 60
    doc_scores = {}
    doc_obj_map = {}

    def update_rrf(docs, weight=1.0):
        # Reciprocal Rank Fusion: each result list contributes weight / (rrf_k + rank + 1) per document.
        for rank, doc in enumerate(docs):
            unique_key = doc.page_content + doc.metadata.get("source", "")
            if unique_key not in doc_scores:
                doc_scores[unique_key] = 0.0
                doc_obj_map[unique_key] = doc
            doc_scores[unique_key] += weight * (1 / (rrf_k + rank + 1))

    update_rrf([d[0] for d in dense_results])
    update_rrf(bm25_results)
    update_rrf(rule_results, weight=3.0)  # rule-based hits get extra weight

    sorted_docs = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)
    # Limit the reranking candidate pool to 10 documents
    candidates = [doc_obj_map[key] for key, score in sorted_docs[:10]]

    # --- Reranking ---
    yield json.dumps({"type": "log", "content": f"⚖️ [3/4] Cross-Encoder 재순위화 (후보 {len(candidates)}개)..."}) + "\n"
    final_docs = []
    if candidates:
        pairs = []
        for doc in candidates:
            source_prefix = f"[{doc.metadata.get('source', '문서')}] "
            pairs.append([query, source_prefix + doc.page_content])

        # CPU-bound bottleneck: this call blocks the event loop while scoring
        scores = reranker.predict(pairs)
        scored_docs = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
        final_docs = [doc for doc, score in scored_docs[:5]]  # keep the top 5

        # Send detailed rank logs
        for i, (doc, score) in enumerate(scored_docs[:3]):  # expose the top-3 scores
            log_msg = f" - Rank {i+1}: {doc.metadata.get('source')} (Score: {score:.4f})"
            yield json.dumps({"type": "log", "content": log_msg}) + "\n"

    # --- LLM Generation ---
    yield json.dumps({"type": "log", "content": "🤖 [4/4] GPT-5-mini 기반 답변 생성 중..."}) + "\n"
    context_text = ""
    for doc in final_docs:
        source = doc.metadata.get("source", "Unknown")
        context_text += f"📄 문서: {source}\n내용: {doc.page_content}\n\n"

    system_prompt = """당신은 한양대학교 학칙 및 규정 안내 챗봇입니다.
질문에 답변할 때는 반드시 근거 문서의 출처(예: 학칙 제5조, 부전공 시행세칙 등)를 언급하세요.
Context에 없는 내용은 답변하지 마세요."""

    prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        ("user", "Context:\n{context}\n\nQuestion: {question}"),
    ])

    # Streaming LLM Output
    chain = prompt | llm | StrOutputParser()

    # Forward each token to the client as soon as the LLM produces it
    async for token in chain.astream({"context": context_text, "question": query}):
        yield json.dumps({"type": "answer", "content": token}) + "\n"

    yield json.dumps({"type": "log", "content": "✅ 답변 생성 완료."}) + "\n"

    # Send information about the referenced documents
    doc_info = [
        {"source": d.metadata.get("source"), "content": d.page_content[:100] + "..."}
        for d in final_docs
    ]
    yield json.dumps({"type": "docs", "content": doc_info}) + "\n"


@app.post("/chat")
async def chat_endpoint(req: ChatRequest):
    return StreamingResponse(generate_chat_stream(req.query), media_type="application/x-ndjson")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
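
# --- Example client (sketch; not part of the app) ---
# A minimal way to consume the NDJSON stream produced by /chat, assuming the
# server is running locally on port 8000. The httpx dependency and the ask()
# helper below are illustrative only and are not used anywhere in this module.
#
#   import json
#   import httpx
#
#   def ask(query: str):
#       payload = {"query": query, "history": []}
#       with httpx.stream("POST", "http://localhost:8000/chat", json=payload, timeout=None) as r:
#           for line in r.iter_lines():
#               if not line:
#                   continue
#               event = json.loads(line)
#               if event["type"] == "log":
#                   print(event["content"])
#               elif event["type"] == "answer":
#                   print(event["content"], end="", flush=True)
#               elif event["type"] == "docs":
#                   print("\n[sources]", [d["source"] for d in event["content"]])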