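"""Gradio app for a RAG assistant over the Hugging Face Transformers documentation.

Retrieval uses a fine-tuned Jina embedding model with a local Qdrant index;
answers are streamed from DeepSeek through its OpenAI-compatible API.
"""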
import os
import httpx
import gradio as gr
from openai import OpenAI
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer
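
# ================= Configuration =================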
API_KEY = os.environ.get('DEEPSEEK_API_KEY')
BASE_URL = "https://api.deepseek.com"
QDRANT_PATH = "./qdrant_db"
COLLECTION_NAME = "huggingface_transformers_docs"
EMBEDDING_MODEL_ID = "fyerfyer/finetune-jina-transformers-v1"
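

# ================= RAG System =================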
class HFRAG:
    def __init__(self):
        # Embedding model used to encode queries against the pre-built docs index.
        self.embed_model = SentenceTransformer(EMBEDDING_MODEL_ID, trust_remote_code=True)

        # Remove a stale lock file left behind by a previous process, otherwise the
        # embedded Qdrant instance may refuse to open the local database.
        lock_file = os.path.join(QDRANT_PATH, ".lock")
        if os.path.exists(lock_file):
            try:
                os.remove(lock_file)
                print("Cleaned up stale lock file.")
            except OSError:
                pass

        if not os.path.exists(QDRANT_PATH):
            raise ValueError(f"Qdrant path not found: {QDRANT_PATH}.")

        self.db_client = QdrantClient(path=QDRANT_PATH)
        if not self.db_client.collection_exists(COLLECTION_NAME):
            raise ValueError(f"Collection '{COLLECTION_NAME}' not found in Qdrant DB.")
        print("Connected to Qdrant.")

        # DeepSeek exposes an OpenAI-compatible endpoint; proxies are disabled so the
        # client does not pick up proxy settings from the environment.
        self.llm_client = OpenAI(
            api_key=API_KEY,
            base_url=BASE_URL,
            http_client=httpx.Client(proxy=None, trust_env=False),
        )

    def retrieve(self, query: str, top_k: int = 5, score_threshold: float = 0.40):
        query_vector = self.embed_model.encode(query).tolist()
        # Older qdrant-client versions expose `search`; newer ones use `query_points`.
        if hasattr(self.db_client, 'search'):
            results = self.db_client.search(
                collection_name=COLLECTION_NAME,
                query_vector=query_vector,
                limit=top_k,
                score_threshold=score_threshold,
            )
        else:
            results = self.db_client.query_points(
                collection_name=COLLECTION_NAME,
                query=query_vector,
                limit=top_k,
                with_payload=True,
                score_threshold=score_threshold,
            ).points
        return results

    def format_context(self, search_results):
        context_pieces = []
        sources_summary = []
        for idx, hit in enumerate(search_results, 1):
            raw_source = hit.payload['metadata']['source']
            filename = raw_source.split('/')[-1]
            text = hit.payload['text']
            score = hit.score
            sources_summary.append(f"`{filename}` (Score: {score:.2f})")
            piece = f"""<doc id="{idx}" source="{filename}">\n{text}\n</doc>"""
            context_pieces.append(piece)
        return "\n\n".join(context_pieces), sources_summary


rag_system = None


def initialize_system():
    """Create the RAG system lazily and reuse it across requests."""
    global rag_system
    if rag_system is None:
        try:
            rag_system = HFRAG()
        except Exception as e:
            print(f"Error initializing: {e}")
            return None
    return rag_system


# ================= Gradio Logic =================
def predict(message, history):
    rag = initialize_system()
    if not rag:
        yield "❌ System initialization failed. Check logs."
        return
    if not API_KEY:
        yield "❌ Error: `DEEPSEEK_API_KEY` not set in Space secrets."
        return

    # 1. Retrieve
    yield "🔍 Retrieving relevant documents..."
    results = rag.retrieve(message)
    if not results:
        yield "⚠️ No relevant documents found in the knowledge base."
        return

    # 2. Format context
    context_str, sources_list = rag.format_context(results)

    # 3. Build Prompt
system_prompt = """You are an expert AI assistant specializing in the Hugging Face Transformers library.
Your goal is to answer the user's question based ONLY on the provided "Retrieved Context".
GUIDELINES:
1. **Code First**: Prioritize showing Python code examples.
2. **Citation**: Cite source filenames like `[model_doc.md]`.
3. **Honesty**: If the answer isn't in the context, say you don't know.
4. **Format**: Use Markdown."""
user_prompt = f"""### User Query\n{message}\n\n### Retrieved Context\n{context_str}"""
header = "**π Found relevant documents:**\n" + "\n".join([f"- {s}" for s in sources_list]) + "\n\n---\n\n"
current_response = header
yield current_response

    # 4. Generate a streamed answer from DeepSeek.
    try:
        response = rag.llm_client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.1,
            stream=True,
        )
        for chunk in response:
            if chunk.choices[0].delta.content:
                content = chunk.choices[0].delta.content
                current_response += content
                yield current_response
    except Exception as e:
        yield current_response + f"\n\n❌ LLM API Error: {str(e)}"


demo = gr.ChatInterface(
    fn=predict,
    title="🤖 Hugging Face RAG Expert",
    description="Ask me anything about Transformers! Powered by DeepSeek-V3 & Finetuned Embeddings.",
    examples=[
        "How to implement padding?",
        "How to use BERT pipeline?",
        "How to fine-tune a model using Trainer?",
        "What is the difference between padding and truncation?",
    ],
    theme="soft",
)


if __name__ == "__main__":
    demo.launch()