Merge branch 'main' of https://huggingface.co/spaces/NLP-Debater-Project/FastAPI-Backend-Models
Files changed:
- config.py +4 -0
- main.py +39 -3
- models/mcp_models.py +35 -1
- requirements.txt +4 -0
- routes/mcp_routes.py +131 -1
- services/mcp_service.py +29 -1
- topic_similarity_google_example.py +0 -182
- topic_similarity_langchain_example.py +0 -54
config.py
CHANGED
@@ -42,6 +42,9 @@ GROQ_TTS_FORMAT = "wav"
 # **Chat Model**
 GROQ_CHAT_MODEL = "llama3-70b-8192"
 
+# **Topic Extraction Model**
+GROQ_TOPIC_MODEL = "llama-3.3-70b-versatile"  # Latest production model, fallback: "llama3-70b-8192"
+
 # ============ SUPABASE ============
 SUPABASE_URL = os.getenv("SUPABASE_URL", "")
 SUPABASE_KEY = os.getenv("SUPABASE_KEY", "")
@@ -87,6 +90,7 @@ logger.info(f" HF Label Model : {HUGGINGFACE_LABEL_MODEL_ID}")
 logger.info(f" GROQ STT Model : {GROQ_STT_MODEL}")
 logger.info(f" GROQ TTS Model : {GROQ_TTS_MODEL}")
 logger.info(f" GROQ Chat Model : {GROQ_CHAT_MODEL}")
+logger.info(f" GROQ Topic Model: {GROQ_TOPIC_MODEL}")
 logger.info(f" Google API Key : {'✓ Configured' if GOOGLE_API_KEY else '✗ Not configured'}")
 logger.info(f" Supabase URL : {'✓ Configured' if SUPABASE_URL else '✗ Not configured'}")
 logger.info("="*60)
main.py
CHANGED
@@ -55,10 +55,12 @@ def cleanup_on_exit():
 stance_model_manager = None
 kpa_model_manager = None
 generate_model_manager = None
+topic_similarity_service = None
 try:
     from services.stance_model_manager import stance_model_manager
     from services.label_model_manager import kpa_model_manager
     from services.generate_model_manager import generate_model_manager
+    from services.topic_similarity_service import topic_similarity_service
     logger.info("✓ Model managers imported")
 except ImportError as e:
     logger.warning(f"⚠ Could not import model managers: {e}")
@@ -103,6 +105,18 @@ async def lifespan(app: FastAPI):
         logger.error(f"✗ Failed to load Generation model: {str(e)}")
         logger.error("⚠️ Generation endpoints will not work!")
 
+    # Initialize Topic Extraction service (uses Groq LLM)
+    if topic_similarity_service and GROQ_API_KEY:
+        try:
+            logger.info("Initializing Topic Extraction service (Groq LLM)...")
+            topic_similarity_service.initialize()
+            logger.info("✓ Topic Extraction service initialized")
+        except Exception as e:
+            logger.error(f"✗ Failed to initialize Topic Extraction service: {str(e)}")
+            logger.error("⚠️ Topic extraction endpoints will not work!")
+    elif not GROQ_API_KEY:
+        logger.warning("⚠ GROQ_API_KEY not configured. Topic extraction service will not be available.")
+
     logger.info("✓ API startup complete")
     logger.info("https://nlp-debater-project-fastapi-backend-models.hf.space/docs")
 
@@ -149,7 +163,11 @@ async def lifespan(app: FastAPI):
     logger.info(f" STT Model: {GROQ_STT_MODEL}")
     logger.info(f" TTS Model: {GROQ_TTS_MODEL}")
     logger.info(f" Chat Model: {GROQ_CHAT_MODEL}")
+    logger.info(f" Topic Extraction: {'Initialized' if (topic_similarity_service and topic_similarity_service.initialized) else 'Not initialized'}")
+    logger.info(f" Voice Chat: {'Available' if GROQ_API_KEY else 'Disabled (no GROQ_API_KEY)'}")
     logger.info(f" MCP: {'Enabled' if MCP_ENABLED else 'Disabled'}")
+    if MCP_ENABLED:
+        logger.info(" - Tools: detect_stance, match_keypoint_argument, transcribe_audio, generate_speech, generate_argument, extract_topic, voice_chat, health_check")
     logger.info("="*60)
 
     yield
@@ -262,12 +280,21 @@ async def health():
             "stt": GROQ_STT_MODEL if GROQ_API_KEY else "disabled",
             "tts": GROQ_TTS_MODEL if GROQ_API_KEY else "disabled",
             "chat": GROQ_CHAT_MODEL if GROQ_API_KEY else "disabled",
+            "topic_extraction": "initialized" if (topic_similarity_service and hasattr(topic_similarity_service, 'initialized') and topic_similarity_service.initialized) else "not initialized",
+            "voice_chat": "available" if GROQ_API_KEY else "disabled",
             "stance_model": "loaded" if (stance_model_manager and hasattr(stance_model_manager, 'model_loaded') and stance_model_manager.model_loaded) else "not loaded",
             "kpa_model": "loaded" if (kpa_model_manager and hasattr(kpa_model_manager, 'model_loaded') and kpa_model_manager.model_loaded) else "not loaded",
+            "generate_model": "loaded" if (generate_model_manager and hasattr(generate_model_manager, 'model_loaded') and generate_model_manager.model_loaded) else "not loaded",
             "mcp": "enabled" if MCP_ENABLED else "disabled"
         },
         "endpoints": {
-            "mcp": "/api/v1/mcp" if MCP_ENABLED else "disabled"
+            "mcp": "/api/v1/mcp" if MCP_ENABLED else "disabled",
+            "topic_extraction": "/api/v1/topic/extract",
+            "voice_chat": "/voice-chat/voice or /voice-chat/text",
+            "mcp_tools": {
+                "extract_topic": "/api/v1/mcp/tools/extract-topic",
+                "voice_chat": "/api/v1/mcp/tools/voice-chat"
+            } if MCP_ENABLED else "disabled"
         }
     }
     return health_status
@@ -285,13 +312,22 @@ async def not_found_handler(request, exc):
         "GET /health": "Health check",
         "POST /api/v1/stt/": "Speech to text",
         "POST /api/v1/tts/": "Text to speech",
-        "POST /voice-chat/voice": "Voice chat"
+        "POST /voice-chat/voice": "Voice chat (audio input)",
+        "POST /voice-chat/text": "Voice chat (text input)",
+        "POST /api/v1/topic/extract": "Extract topic from text"
     }
     if MCP_ENABLED:
         endpoints.update({
             "GET /api/v1/mcp/health": "MCP health check",
             "GET /api/v1/mcp/tools": "List MCP tools",
-            "POST /api/v1/mcp/tools/call": "Call MCP tool"
+            "POST /api/v1/mcp/tools/call": "Call MCP tool",
+            "POST /api/v1/mcp/tools/extract-topic": "Extract topic (MCP tool)",
+            "POST /api/v1/mcp/tools/voice-chat": "Voice chat (MCP tool)",
+            "POST /api/v1/mcp/tools/detect-stance": "Detect stance (MCP tool)",
+            "POST /api/v1/mcp/tools/match-keypoint": "Match keypoint (MCP tool)",
+            "POST /api/v1/mcp/tools/transcribe-audio": "Transcribe audio (MCP tool)",
+            "POST /api/v1/mcp/tools/generate-speech": "Generate speech (MCP tool)",
+            "POST /api/v1/mcp/tools/generate-argument": "Generate argument (MCP tool)"
         })
     return {
         "error": "Not Found",
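A quick way to check the enriched health payload once the app is running (sketch: the hunks above only show the inner entries, so the "services" wrapper key is an assumption):

```python
# Sketch: probing the new /health fields added in this commit.
# The "services" key wrapping the model entries is an assumption.
import requests

BASE = "https://nlp-debater-project-fastapi-backend-models.hf.space"

health = requests.get(f"{BASE}/health").json()
print(health["services"]["topic_extraction"])   # "initialized" / "not initialized"
print(health["endpoints"]["topic_extraction"])  # "/api/v1/topic/extract"
```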
models/mcp_models.py
CHANGED
@@ -15,7 +15,7 @@ class ToolCallRequest(BaseModel):
         }
     )
 
-    tool_name: str = Field(..., description="Name of the MCP tool to call (e.g., 'detect_stance', 'match_keypoint_argument', 'transcribe_audio', 'generate_speech', 'generate_argument')")
+    tool_name: str = Field(..., description="Name of the MCP tool to call (e.g., 'detect_stance', 'match_keypoint_argument', 'transcribe_audio', 'generate_speech', 'generate_argument', 'extract_topic', 'voice_chat')")
     arguments: Dict[str, Any] = Field(default_factory=dict, description="Arguments for the tool (varies by tool)")
 
 class ToolCallResponse(BaseModel):
@@ -105,6 +105,40 @@ class GenerateSpeechResponse(BaseModel):
 
     audio_path: str = Field(..., description="Path to generated audio file")
 
+class ExtractTopicResponse(BaseModel):
+    """Response model for topic extraction"""
+    model_config = ConfigDict(
+        json_schema_extra={
+            "example": {
+                "text": "Governments should subsidize electric cars to encourage adoption.",
+                "topic": "government subsidies for electric vehicle adoption",
+                "timestamp": "2024-01-01T12:00:00"
+            }
+        }
+    )
+
+    text: str = Field(..., description="The input text")
+    topic: str = Field(..., description="The extracted topic")
+    timestamp: Optional[str] = Field(None, description="Timestamp of extraction")
+
+class VoiceChatResponse(BaseModel):
+    """Response model for voice chat"""
+    model_config = ConfigDict(
+        json_schema_extra={
+            "example": {
+                "user_input": "What is climate change?",
+                "conversation_id": "uuid-here",
+                "response": "Climate change refers to long-term changes in global temperatures and weather patterns.",
+                "timestamp": "2024-01-01T12:00:00"
+            }
+        }
+    )
+
+    user_input: str = Field(..., description="The user's input text")
+    conversation_id: Optional[str] = Field(None, description="The conversation ID")
+    response: str = Field(..., description="The chatbot's response")
+    timestamp: Optional[str] = Field(None, description="Timestamp of response")
+
 class ResourceInfo(BaseModel):
     """Information about an MCP resource"""
     uri: str
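A minimal sanity check for the new response model, mirroring the schema example above (pydantic v2, consistent with the ConfigDict usage in the file):

```python
# Sketch: constructing and serializing the new ExtractTopicResponse model.
from models.mcp_models import ExtractTopicResponse

resp = ExtractTopicResponse(
    text="Governments should subsidize electric cars to encourage adoption.",
    topic="government subsidies for electric vehicle adoption",
    timestamp="2024-01-01T12:00:00",
)
print(resp.model_dump_json(indent=2))  # pydantic v2 serialization
```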
requirements.txt
CHANGED
@@ -16,6 +16,10 @@ langchain-core>=0.1.0
 langchain-groq>=0.1.0
 langsmith>=0.1.0
 
+# Fix urllib3 compatibility issues
+urllib3>=1.26.0,<3.0.0
+requests-toolbelt>=1.0.0
+
 # Audio processing (optional, if you need local processing)
 soundfile>=0.12.1
 
routes/mcp_routes.py
CHANGED
@@ -14,6 +14,8 @@ from services.mcp_service import mcp_server
 from services.stance_model_manager import stance_model_manager
 from services.label_model_manager import kpa_model_manager
 from services.generate_model_manager import generate_model_manager
+from services.topic_service import topic_service
+from services.chat_service import generate_chat_response
 from models.mcp_models import (
     ToolListResponse,
     ToolInfo,
@@ -22,7 +24,9 @@ from models.mcp_models import (
     DetectStanceResponse,
     MatchKeypointResponse,
     TranscribeAudioResponse,
-    GenerateSpeechResponse
+    GenerateSpeechResponse,
+    ExtractTopicResponse,
+    VoiceChatResponse
 )
 from models.generate import GenerateRequest, GenerateResponse
 from datetime import datetime
@@ -75,6 +79,30 @@ class GenerateSpeechRequest(BaseModel):
         }
     }
 
+class ExtractTopicRequest(BaseModel):
+    """Request to extract a topic from a text"""
+    text: str = Field(..., min_length=5, max_length=5000, description="The text/argument to extract the topic from")
+
+    class Config:
+        json_schema_extra = {
+            "example": {
+                "text": "Governments should subsidize electric cars to encourage adoption."
+            }
+        }
+
+class VoiceChatRequest(BaseModel):
+    """Request to generate a voice chatbot response"""
+    user_input: str = Field(..., description="The user input (in English)")
+    conversation_id: Optional[str] = Field(None, description="Conversation ID to maintain context")
+
+    class Config:
+        json_schema_extra = {
+            "example": {
+                "user_input": "What is climate change?",
+                "conversation_id": "optional-conversation-id"
+            }
+        }
+
 
 # ===== MCP Routes =====
@@ -90,6 +118,8 @@ async def mcp_health():
         "transcribe_audio",
         "generate_speech",
         "generate_argument",
+        "extract_topic",
+        "voice_chat",
         "health_check"
     ]
     return {
@@ -167,6 +197,29 @@ async def list_mcp_tools():
                 "required": ["topic", "position"]
             }
         ),
+        ToolInfo(
+            name="extract_topic",
+            description="Extracts a topic from a given text/argument",
+            input_schema={
+                "type": "object",
+                "properties": {
+                    "text": {"type": "string", "description": "The text/argument to extract the topic from"}
+                },
+                "required": ["text"]
+            }
+        ),
+        ToolInfo(
+            name="voice_chat",
+            description="Generates a voice chatbot response in English",
+            input_schema={
+                "type": "object",
+                "properties": {
+                    "user_input": {"type": "string", "description": "The user input (in English)"},
+                    "conversation_id": {"type": "string", "description": "Conversation ID to maintain context (optional)"}
+                },
+                "required": ["user_input"]
+            }
+        ),
         ToolInfo(
             name="health_check",
             description="Health check for the MCP server",
@@ -244,6 +297,27 @@ async def call_mcp_tool(request: ToolCallRequest):
         }
     }
     ```
+
+    6. **extract_topic** - Extract a topic from a text:
+    ```json
+    {
+        "tool_name": "extract_topic",
+        "arguments": {
+            "text": "Governments should subsidize electric cars to encourage adoption."
+        }
+    }
+    ```
+
+    7. **voice_chat** - Generate a voice chatbot response:
+    ```json
+    {
+        "tool_name": "voice_chat",
+        "arguments": {
+            "user_input": "What is climate change?",
+            "conversation_id": "optional-conversation-id"
+        }
+    }
+    ```
     """
     try:
         result = await mcp_server.call_tool(request.tool_name, request.arguments)
@@ -510,6 +584,62 @@ async def mcp_generate_argument(request: GenerateRequest):
         logger.error(f"Error in generate_argument: {e}", exc_info=True)
         raise HTTPException(status_code=500, detail=f"Error executing tool generate_argument: {e}")
 
+@router.post("/tools/extract-topic", response_model=ExtractTopicResponse, summary="Extract a topic from a text")
+async def mcp_extract_topic(request: ExtractTopicRequest):
+    """Extracts a topic from a given text/argument"""
+    try:
+        # Check that the service is initialized
+        if not topic_service.initialized:
+            topic_service.initialize()
+
+        # Call the service directly (more reliable than going through MCP)
+        topic_text = topic_service.extract_topic(request.text)
+
+        # Build the structured response
+        response = ExtractTopicResponse(
+            text=request.text,
+            topic=topic_text,
+            timestamp=datetime.now().isoformat()
+        )
+
+        logger.info(f"Topic extracted from text '{request.text[:50]}...': {topic_text[:50]}...")
+        return response
+
+    except ValueError as e:
+        logger.error(f"Validation error in extract_topic: {str(e)}")
+        raise HTTPException(status_code=400, detail=str(e))
+    except Exception as e:
+        logger.error(f"Error in extract_topic: {e}", exc_info=True)
+        raise HTTPException(status_code=500, detail=f"Error executing tool extract_topic: {e}")
+
+@router.post("/tools/voice-chat", response_model=VoiceChatResponse, summary="Generate a voice chatbot response")
+async def mcp_voice_chat(request: VoiceChatRequest):
+    """Generates a voice chatbot response in English"""
+    try:
+        # Call the service directly (more reliable than going through MCP)
+        response_text = generate_chat_response(
+            user_input=request.user_input,
+            conversation_id=request.conversation_id
+        )
+
+        # Build the structured response
+        response = VoiceChatResponse(
+            user_input=request.user_input,
+            conversation_id=request.conversation_id,
+            response=response_text,
+            timestamp=datetime.now().isoformat()
+        )
+
+        logger.info(f"Voice chat response generated for input '{request.user_input[:50]}...': {response_text[:50]}...")
+        return response
+
+    except ValueError as e:
+        logger.error(f"Validation error in voice_chat: {str(e)}")
+        raise HTTPException(status_code=400, detail=str(e))
+    except Exception as e:
+        logger.error(f"Error in voice_chat: {e}", exc_info=True)
+        raise HTTPException(status_code=500, detail=f"Error executing tool voice_chat: {e}")
+
 @router.get("/tools/health-check", summary="MCP health check (tool)")
 async def mcp_tool_health_check() -> Dict[str, Any]:
     """Health check via the MCP tool"""
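The two new direct endpoints can be exercised as below (sketch; the base URL is the Space URL from the startup log in main.py, and the response fields match the new response models):

```python
# Sketch: calling the new direct MCP tool endpoints added in this commit.
import requests

BASE = "https://nlp-debater-project-fastapi-backend-models.hf.space"

# Extract a topic from an argument
r = requests.post(
    f"{BASE}/api/v1/mcp/tools/extract-topic",
    json={"text": "Governments should subsidize electric cars to encourage adoption."},
)
print(r.json()["topic"])

# Get a chatbot reply (conversation_id is optional)
r = requests.post(
    f"{BASE}/api/v1/mcp/tools/voice-chat",
    json={"user_input": "What is climate change?"},
)
print(r.json()["response"])
```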
services/mcp_service.py
CHANGED
@@ -1,7 +1,7 @@
 """Service to initialize the MCP server with FastMCP"""
 
 from mcp.server.fastmcp import FastMCP
-from typing import Dict, Any
+from typing import Dict, Any, Optional
 import logging
 
 from fastapi import FastAPI
@@ -11,6 +11,8 @@ from services.label_model_manager import kpa_model_manager
 from services.stt_service import speech_to_text
 from services.tts_service import text_to_speech
 from services.generate_model_manager import generate_model_manager
+from services.topic_service import topic_service
+from services.chat_service import generate_chat_response
 
 logger = logging.getLogger(__name__)
 
@@ -62,6 +64,30 @@ def generate_argument(topic: str, position: str) -> Dict[str, Any]:
         "argument": argument
     }
 
+@mcp_server.tool()
+def extract_topic(text: str) -> Dict[str, Any]:
+    """Extract a topic from the given text/argument"""
+    if not topic_service.initialized:
+        topic_service.initialize()
+    topic = topic_service.extract_topic(text)
+    return {
+        "text": text,
+        "topic": topic
+    }
+
+@mcp_server.tool()
+def voice_chat(user_input: str, conversation_id: Optional[str] = None) -> Dict[str, Any]:
+    """Generate a chatbot response for voice chat (English only)"""
+    response_text = generate_chat_response(
+        user_input=user_input,
+        conversation_id=conversation_id
+    )
+    return {
+        "user_input": user_input,
+        "conversation_id": conversation_id,
+        "response": response_text
+    }
+
 @mcp_server.resource("debate://prompt")
 def get_debate_prompt() -> str:
     return "You are a debate expert. Generate 3 PRO arguments for the given topic. Be concise and persuasive."
@@ -78,6 +104,8 @@ def health_check() -> Dict[str, Any]:
             "transcribe_audio",
             "generate_speech",
             "generate_argument",
+            "extract_topic",
+            "voice_chat",
             "health_check"
         ]
     except Exception:
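The same tools remain reachable through the generic MCP dispatch route, matching examples 6 and 7 in the call_mcp_tool docstring (sketch; same assumed base URL as above):

```python
# Sketch: invoking extract_topic via the generic /tools/call dispatcher.
import requests

BASE = "https://nlp-debater-project-fastapi-backend-models.hf.space"

payload = {
    "tool_name": "extract_topic",
    "arguments": {"text": "Governments should subsidize electric cars to encourage adoption."},
}
print(requests.post(f"{BASE}/api/v1/mcp/tools/call", json=payload).json())
```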
topic_similarity_google_example.py
DELETED
@@ -1,182 +0,0 @@
-from datetime import datetime
-import os
-import json
-import hashlib
-from pathlib import Path
-from dotenv import load_dotenv
-from google import genai
-from google.genai import types
-import numpy as np
-from sklearn.metrics.pairwise import cosine_similarity
-
-# Load environment variables from .env file
-load_dotenv()
-GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
-if not GOOGLE_API_KEY:
-    raise ValueError("GOOGLE_API_KEY is not set in environment variables.")
-
-# Get the path to topics.json relative to this file
-TOPICS_FILE = Path(__file__).parent.parent / "data" / "topics.json"
-# Cache file for topic embeddings
-EMBEDDINGS_CACHE_FILE = Path(__file__).parent.parent / "data" / "topic_embeddings_cache.json"
-
-# Create a Google Generative AI client with the API key
-client = genai.Client(api_key=GOOGLE_API_KEY)
-
-
-def load_topics():
-    """Load topics from topics.json file."""
-    with open(TOPICS_FILE, 'r', encoding='utf-8') as f:
-        data = json.load(f)
-    return data.get("topics", [])
-
-
-def get_topics_hash(topics):
-    """Generate a hash of the topics list to verify cache validity."""
-    topics_str = json.dumps(topics, sort_keys=True)
-    return hashlib.md5(topics_str.encode('utf-8')).hexdigest()
-
-
-def load_cached_embeddings():
-    """Load cached topic embeddings if they exist and are valid."""
-    if not EMBEDDINGS_CACHE_FILE.exists():
-        return None
-
-    try:
-        with open(EMBEDDINGS_CACHE_FILE, 'r', encoding='utf-8') as f:
-            cache_data = json.load(f)
-
-        # Verify cache is valid by checking topics hash
-        current_topics = load_topics()
-        current_hash = get_topics_hash(current_topics)
-
-        if cache_data.get("topics_hash") == current_hash:
-            # Convert list embeddings back to numpy arrays
-            embeddings = [np.array(emb) for emb in cache_data.get("embeddings", [])]
-            return embeddings
-        else:
-            # Topics have changed, cache is invalid
-            return None
-    except (json.JSONDecodeError, KeyError, ValueError) as e:
-        # Cache file is corrupted or invalid format
-        print(f"Warning: Could not load cached embeddings: {e}")
-        return None
-
-
-def save_cached_embeddings(embeddings, topics):
-    """Save topic embeddings to cache file."""
-    topics_hash = get_topics_hash(topics)
-
-    # Convert numpy arrays to lists for JSON serialization
-    embeddings_list = [emb.tolist() for emb in embeddings]
-
-    cache_data = {
-        "topics_hash": topics_hash,
-        "embeddings": embeddings_list,
-        "model": "models/text-embedding-004",
-        "cached_at": datetime.now().isoformat()
-    }
-
-    try:
-        with open(EMBEDDINGS_CACHE_FILE, 'w', encoding='utf-8') as f:
-            json.dump(cache_data, f, indent=2)
-        print(f"Cached {len(embeddings)} topic embeddings to {EMBEDDINGS_CACHE_FILE}")
-    except Exception as e:
-        print(f"Warning: Could not save cached embeddings: {e}")
-
-
-def get_topic_embeddings():
-    """
-    Get topic embeddings, loading from cache if available, otherwise generating and caching them.
-
-    Returns:
-        numpy.ndarray: Array of topic embeddings
-    """
-    topics = load_topics()
-
-    # Try to load from cache first
-    cached_embeddings = load_cached_embeddings()
-    if cached_embeddings is not None:
-        print(f"Loaded {len(cached_embeddings)} topic embeddings from cache")
-        return np.array(cached_embeddings)
-
-    # Cache miss or invalid - generate embeddings
-    print(f"Generating embeddings for {len(topics)} topics (this may take a moment)...")
-    embedding_response = client.models.embed_content(
-        model="models/text-embedding-004",
-        contents=topics,
-        config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
-    )
-
-    if not hasattr(embedding_response, "embeddings") or embedding_response.embeddings is None:
-        raise RuntimeError("Embedding API did not return embeddings.")
-
-    embeddings = [np.array(e.values) for e in embedding_response.embeddings]
-
-    # Save to cache for future use
-    save_cached_embeddings(embeddings, topics)
-
-    return np.array(embeddings)
-
-
-def find_most_similar_topic(input_text: str):
-    """
-    Compare a single input text to all topics and return the highest cosine similarity.
-    Uses cached topic embeddings to avoid re-embedding topics on every call.
-
-    Args:
-        input_text: The text to compare against topics
-
-    Returns:
-        dict: Contains 'topic', 'similarity', and 'index' of the most similar topic
-    """
-    # Load topics from JSON file
-    topics = load_topics()
-
-    if not topics:
-        raise ValueError("No topics found in topics.json")
-
-    # Get topic embeddings (from cache or generate)
-    topic_embeddings = get_topic_embeddings()
-
-    # Only embed the input text (much faster!)
-    embedding_response = client.models.embed_content(
-        model="models/text-embedding-004",
-        contents=[input_text],
-        config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
-    )
-
-    if not hasattr(embedding_response, "embeddings") or embedding_response.embeddings is None:
-        raise RuntimeError("Embedding API did not return embeddings.")
-
-    # Extract input embedding
-    input_embedding = np.array(embedding_response.embeddings[0].values).reshape(1, -1)
-
-    # Calculate cosine similarity between input and each topic
-    similarities = cosine_similarity(input_embedding, topic_embeddings)[0]
-
-    # Find the highest similarity
-    max_index = np.argmax(similarities)
-    max_similarity = similarities[max_index]
-    most_similar_topic = topics[max_index]
-
-    return {
-        "topic": most_similar_topic,
-        "similarity": float(max_similarity),
-        "index": int(max_index)
-    }
-
-
-if __name__ == "__main__":
-    # Example usage
-    # start time
-    start_time = datetime.now()
-    test_text = "we should abandon the use of school uniform since one should be allowed to express their individuality by the clothes they were."
-    result = find_most_similar_topic(test_text)
-    print(f"Input text: '{test_text}'")
-    print(f"Most similar topic: '{result['topic']}'")
-    print(f"Cosine similarity: {result['similarity']:.4f}%")
-    # end time
-    end_time = datetime.now()
-    # in seconds
-    print(f"Time taken: {(end_time - start_time).total_seconds()} seconds")
topic_similarity_langchain_example.py
DELETED
@@ -1,54 +0,0 @@
-import json
-import os
-from datetime import datetime
-from dotenv import load_dotenv
-load_dotenv()
-
-from langchain_community.vectorstores import FAISS
-from langchain_core.example_selectors import (
-    SemanticSimilarityExampleSelector,
-)
-from langchain_google_genai import GoogleGenerativeAIEmbeddings
-
-# Load topics from data file
-with open(
-    file="data/topics.json",
-    encoding="utf-8"
-) as f:
-    data = json.load(f)
-
-# Make sure each example is a dict with "topic" key (wrap as dict if plain string)
-def format_examples(examples):
-    formatted = []
-    for ex in examples:
-        if isinstance(ex, str):
-            formatted.append({"topic": ex})
-        elif isinstance(ex, dict) and "topic" in ex:
-            formatted.append({"topic": ex["topic"]})
-        else:
-            formatted.append({"topic": str(ex)})
-    return formatted
-
-# topics.json should have a top-level "topics" key
-examples = data.get("topics", [])
-formatted_examples = format_examples(examples)
-
-start_time = datetime.now()
-example_selector = SemanticSimilarityExampleSelector.from_examples(
-    examples=formatted_examples,
-    embeddings=GoogleGenerativeAIEmbeddings(
-        model="models/text-embedding-004",
-        api_key=os.getenv("GOOGLE_API_KEY")
-    ),
-    vectorstore_cls=FAISS,
-    k=1,
-    input_keys=["topic"],
-)
-
-# Example call to selector (for demonstration; remove in production)
-result = example_selector.select_examples(
-    {"topic": "people who are terminally ill and suffering greatly should have the right to end their own life if they so desire."}
-)
-print(result)
-end_time = datetime.now()
-print(f"Time taken: {(end_time - start_time).total_seconds()} seconds")