Commit 2f235a0 (parent: 6818d41)
working the rag and web server

Files changed:

- .gitignore +2 -1
- README.md +1 -1
- backend/api/main.py +42 -1
- backend/api/mcp_clients/mcp_client.py +26 -0
- backend/api/models/__init__.py +18 -0
- backend/api/models/agent.py +24 -0
- backend/api/models/redflag.py +23 -0
- backend/api/routes/agent.py +44 -27
- backend/api/services/agent_orchestrator.py +257 -64
- backend/api/services/intent_classifier.py +31 -22
- backend/api/services/llm_client.py +47 -24
- backend/api/services/redflag_detector.py +158 -49
- backend/api/services/tool_selector.py +159 -26
- backend/mcp_servers/admin_server.py +51 -0
- backend/mcp_servers/models/__init__.py +18 -0
- backend/mcp_servers/models/admin.py +14 -0
- backend/mcp_servers/models/rag.py +12 -0
- backend/mcp_servers/models/web.py +7 -0
- backend/mcp_servers/rag_server.py +60 -0
- backend/mcp_servers/web_server.py +71 -0
- backend/tests/conftest.py +1 -0
- backend/tests/test_agent_orchestrator.py +208 -9
- backend/tests/test_intent.py +97 -26
- env.example +41 -0
- pytest.ini +7 -0
- requirements.txt +4 -1

.gitignore (CHANGED):

```diff
@@ -1,2 +1,3 @@
 venv/
-.env
+.env
+.pytest_cache
```

README.md (CHANGED):

```diff
@@ -441,7 +441,7 @@ docker-compose up -d
 | **Category** | Enterprise |
 | **Tag** | `mcp-in-action-track-enterprise` |
 | **Project Name** | **IntegraChat** |
-
+
 ### Short Summary
 
 > IntegraChat is a multi-tenant AI platform where autonomous MCP-powered agents retrieve private knowledge using RAG, access live web information, and enforce admin-defined safety rules via a red-flag compliance system. It includes an analytics dashboard, tool-selection engine, and strict tenant isolation.
```

backend/api/main.py (CHANGED; new version, with middleware options that are unchanged and hidden in the diff left elided):

```python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import os
import sys
from pathlib import Path

# ------------------------------------------------------------
# Fix Python paths so imports like backend.api.routes.agent work
# ------------------------------------------------------------
root_dir = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(root_dir))

# ------------------------------------------------------------
# Import ALL routers correctly
# ------------------------------------------------------------
from backend.api.routes.agent import router as agent_router
from backend.api.routes.admin import router as admin_router
from backend.api.routes.rag import router as rag_router
from backend.api.routes.web import router as web_router
from backend.api.routes.analytics import router as analytics_router

# ------------------------------------------------------------
# Main FastAPI app
# ------------------------------------------------------------
app = FastAPI(title="IntegraChat API", version="1.0.0")

app.add_middleware(
    CORSMiddleware,
    # ... (middleware options unchanged and elided in the diff)
    allow_headers=["*"],
)

# ------------------------------------------------------------
# Route Registration (THIS FIXES YOUR 404)
# ------------------------------------------------------------
app.include_router(agent_router, prefix="/agent", tags=["Agent"])
app.include_router(admin_router, prefix="/admin", tags=["Admin"])
app.include_router(rag_router, prefix="/rag", tags=["RAG"])
app.include_router(web_router, prefix="/web", tags=["Web"])
app.include_router(analytics_router, prefix="/analytics", tags=["Analytics"])

# ------------------------------------------------------------
# Health Check
# ------------------------------------------------------------
@app.get("/health")
def health():
    return {"status": "ok", "version": "1.0.0"}

# ------------------------------------------------------------
# Local Run
# ------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    port = int(os.getenv("API_PORT", "8000"))
    uvicorn.run("backend.api.main:app", host="0.0.0.0", port=port, reload=True)
```

backend/api/mcp_clients/mcp_client.py (ADDED):

```python
import httpx
from dataclasses import dataclass, field


@dataclass
class MCPClient:
    rag_url: str
    web_url: str
    admin_url: str
    client: httpx.AsyncClient = field(default_factory=lambda: httpx.AsyncClient(timeout=30))

    async def call_rag(self, tenant_id: str, query: str):
        r = await self.client.post(f"{self.rag_url}/search", json={"tenant_id": tenant_id, "query": query})
        return r.json()

    async def call_web(self, tenant_id: str, query: str):
        r = await self.client.post(f"{self.web_url}/search", json={"tenant_id": tenant_id, "query": query})
        return r.json()

    async def call_admin(self, tenant_id: str, query: str):
        r = await self.client.post(f"{self.admin_url}/eval", json={"tenant_id": tenant_id, "query": query})
        return r.json()
```
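A quick way to sanity-check this client is to point it at the three MCP servers added later in this commit. This sketch is mine, not part of the diff; it assumes the servers are running on their default ports and uses a made-up tenant id.

```python
import asyncio

from backend.api.mcp_clients.mcp_client import MCPClient


async def main():
    # Ports match the uvicorn.run() calls in the MCP server files below.
    mcp = MCPClient(
        rag_url="http://localhost:8001",
        web_url="http://localhost:8002",
        admin_url="http://localhost:8003",
    )
    # Each call POSTs {"tenant_id", "query"} and returns the parsed JSON body.
    verdict = await mcp.call_admin("tenant-a", "please delete all data")
    print(verdict)  # expected: {"action": "block", "reason": "delete all data"}


asyncio.run(main())
```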

backend/api/models/__init__.py (ADDED):

```python
"""
API Models Package

This package contains all Pydantic and dataclass models used across the API.
"""

from .agent import AgentRequest, AgentDecision, AgentResponse
from .redflag import RedFlagRule, RedFlagMatch

__all__ = [
    "AgentRequest",
    "AgentDecision",
    "AgentResponse",
    "RedFlagRule",
    "RedFlagMatch",
]
```

backend/api/models/agent.py (ADDED):

```python
from typing import Dict, Any, List
from pydantic import BaseModel


class AgentRequest(BaseModel):
    tenant_id: str
    user_id: str | None
    message: str
    conversation_history: List[Dict[str, str]] = []
    temperature: float = 0.0


class AgentDecision(BaseModel):
    action: str
    tool: str | None
    tool_input: Dict[str, Any] | None
    reason: str | None


class AgentResponse(BaseModel):
    text: str
    decision: AgentDecision
    tool_traces: List[Dict[str, Any]] = []
```

backend/api/models/redflag.py (ADDED):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class RedFlagRule:
    id: str
    pattern: str
    description: str
    severity: str  # e.g., "low", "medium", "high", "critical"
    source: str  # "admin", "system"
    enabled: bool = True
    keywords: List[str] = field(default_factory=list)


@dataclass
class RedFlagMatch:
    rule_id: str
    pattern: str
    severity: str
    description: str
    matched_text: str
```

backend/api/routes/agent.py (CHANGED; new version):

```python
# =============================================================
# File: backend/api/routes/agent.py
# =============================================================

from fastapi import APIRouter
from pydantic import BaseModel
import os
import sys
from pathlib import Path

# Add backend to path for imports
backend_dir = Path(__file__).parent.parent.parent
sys.path.insert(0, str(backend_dir))

from api.services.agent_orchestrator import AgentOrchestrator
from api.models.agent import AgentRequest, AgentResponse


router = APIRouter()


orchestrator = AgentOrchestrator(
    rag_mcp_url=os.getenv("RAG_MCP_URL", "http://localhost:8001"),
    web_mcp_url=os.getenv("WEB_MCP_URL", "http://localhost:8002"),
    admin_mcp_url=os.getenv("ADMIN_MCP_URL", "http://localhost:8003"),
    llm_backend=os.getenv("LLM_BACKEND", "ollama")
)


class ChatRequest(BaseModel):
    tenant_id: str
    user_id: str | None = None
    message: str
    conversation_history: list[dict] = []
    temperature: float = 0.0


@router.post("/message", response_model=AgentResponse)
async def agent_chat(req: ChatRequest):
    agent_req = AgentRequest(
        tenant_id=req.tenant_id,
        user_id=req.user_id,
        message=req.message,
        conversation_history=req.conversation_history,
        temperature=req.temperature
    )
    return await orchestrator.handle(agent_req)
```
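A minimal client-side sketch for exercising the new route once the API is up (uvicorn backend.api.main:app). The tenant id and question are hypothetical; the /agent prefix comes from the router registration in main.py.

```python
import httpx

resp = httpx.post(
    "http://localhost:8000/agent/message",
    json={
        "tenant_id": "tenant-a",  # hypothetical tenant
        "message": "What is our refund policy?",
        "conversation_history": [],
        "temperature": 0.0,
    },
    timeout=60.0,
)
resp.raise_for_status()
body = resp.json()
# AgentResponse fields: text, decision, tool_traces
print(body["decision"]["action"], "->", body["text"][:200])
```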

backend/api/services/agent_orchestrator.py (CHANGED; new version):

```python
# =============================================================
# File: backend/api/services/agent_orchestrator.py
# =============================================================
"""
Agent Orchestrator (integrated with enterprise RedFlagDetector)

Place at: backend/api/services/agent_orchestrator.py
"""

from __future__ import annotations

import json
import os
from typing import List, Dict, Any, Optional

from ..models.agent import AgentRequest, AgentDecision, AgentResponse
from ..models.redflag import RedFlagMatch
from .redflag_detector import RedFlagDetector
from .intent_classifier import IntentClassifier
from .tool_selector import ToolSelector
from .llm_client import LLMClient
from ..mcp_clients.mcp_client import MCPClient


class AgentOrchestrator:

    def __init__(self, rag_mcp_url: str, web_mcp_url: str, admin_mcp_url: str, llm_backend: str = "ollama"):
        self.mcp = MCPClient(rag_mcp_url, web_mcp_url, admin_mcp_url)
        self.llm = LLMClient(backend=llm_backend, url=os.getenv("OLLAMA_URL"), api_key=os.getenv("GROQ_API_KEY"), model=os.getenv("OLLAMA_MODEL"))

        # pass admin_mcp_url so the detector can call back
        self.redflag = RedFlagDetector(
            supabase_url=os.getenv("SUPABASE_URL"),
            supabase_key=os.getenv("SUPABASE_SERVICE_KEY"),
            admin_mcp_url=admin_mcp_url
        )

        self.intent = IntentClassifier(llm_client=self.llm)
        self.selector = ToolSelector(llm_client=self.llm)

    async def handle(self, req: AgentRequest) -> AgentResponse:
        # 1) Red-flag check (async)
        matches: List[RedFlagMatch] = await self.redflag.check(req.tenant_id, req.message)

        if matches:
            # We await the admin notification so the alert lands before we respond;
            # it could instead be fired without awaiting to keep it off the response path.
            try:
                await self.redflag.notify_admin(req.tenant_id, matches, source_payload={"message": req.message, "user_id": req.user_id})
            except Exception:
                pass

            decision = AgentDecision(
                action="block",
                tool="admin",
                tool_input={"violations": [m.__dict__ for m in matches]},
                reason="redflag_triggered"
            )
            return AgentResponse(
                text="Your request has been blocked due to policy.",
                decision=decision,
                tool_traces=[{"redflags": [m.__dict__ for m in matches]}]
            )

        # 2) Intent classification
        intent = await self.intent.classify(req.message)

        # 2.5) Pre-fetch RAG results if available (for tool selector context)
        rag_prefetch = None
        rag_results = []
        try:
            # Try to pre-fetch RAG to help the tool selector make better decisions
            rag_prefetch = await self.mcp.call_rag(req.tenant_id, req.message)
            if isinstance(rag_prefetch, dict):
                rag_results = rag_prefetch.get("results") or rag_prefetch.get("hits") or []
        except Exception:
            # If RAG fails, continue without it
            pass

        # 3) Tool selection (hybrid) - pass RAG results in context
        ctx = {
            "tenant_id": req.tenant_id,
            "rag_results": rag_results
        }
        decision = await self.selector.select(intent, req.message, ctx)

        tool_traces: List[Dict[str, Any]] = []

        # 4) Handle multi-step tool execution
        if decision.action == "multi_step" and decision.tool_input:
            steps = decision.tool_input.get("steps", [])
            if steps:
                return await self._execute_multi_step(req, steps, decision, tool_traces, rag_prefetch)

        # 5) Execute single tool
        if decision.action == "call_tool" and decision.tool:
            try:
                if decision.tool == "rag":
                    rag_resp = await self.mcp.call_rag(req.tenant_id, decision.tool_input.get("query") if decision.tool_input else req.message)
                    tool_traces.append({"tool": "rag", "response": rag_resp})
                    prompt = self._build_prompt_with_rag(req, rag_resp)
                    llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
                    return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces)

                if decision.tool == "web":
                    web_resp = await self.mcp.call_web(req.tenant_id, decision.tool_input.get("query") if decision.tool_input else req.message)
                    tool_traces.append({"tool": "web", "response": web_resp})
                    prompt = self._build_prompt_with_web(req, web_resp)
                    llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
                    return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces)

                if decision.tool == "admin":
                    admin_resp = await self.mcp.call_admin(req.tenant_id, decision.tool_input.get("query") if decision.tool_input else req.message)
                    tool_traces.append({"tool": "admin", "response": admin_resp})
                    return AgentResponse(text=json.dumps(admin_resp), decision=decision, tool_traces=tool_traces)

                if decision.tool == "llm":
                    llm_out = await self.llm.simple_call(req.message, temperature=req.temperature)
                    return AgentResponse(text=llm_out, decision=decision)

            except Exception as e:
                tool_traces.append({"tool": decision.tool, "error": str(e)})
                try:
                    fallback = await self.llm.simple_call(req.message, temperature=req.temperature)
                except Exception as llm_error:
                    fallback = f"I encountered an error while processing your request: {str(e)}. Additionally, the AI service is unavailable: {str(llm_error)}"
                return AgentResponse(
                    text=fallback,
                    decision=AgentDecision(action="respond", tool=None, tool_input=None, reason=f"tool_error_fallback: {e}"),
                    tool_traces=tool_traces
                )

        # Default: direct LLM response
        try:
            llm_out = await self.llm.simple_call(req.message, temperature=req.temperature)
        except Exception as e:
            # If the LLM fails, return a helpful error message
            llm_out = f"I apologize, but I'm unable to process your request right now. The AI service is unavailable: {str(e)}"

        return AgentResponse(
            text=llm_out,
            decision=AgentDecision(action="respond", tool=None, tool_input=None, reason="default_llm")
        )

    def _build_prompt_with_rag(self, req: AgentRequest, rag_resp: Dict[str, Any]) -> str:
        snippets = []
        if isinstance(rag_resp, dict):
            hits = rag_resp.get("results") or rag_resp.get("hits") or []
            for h in hits[:5]:
                txt = h.get("text") or h.get("content") or str(h)
                snippets.append(txt)

        snippet_text = "\n---\n".join(snippets) or ""
        prompt = (
            f"You are an assistant helping tenant {req.tenant_id}. Use the following retrieved documents to answer the user's question.\n"
            f"Documents:\n{snippet_text}\n\n"
            f"User question: {req.message}\nProvide a concise, accurate answer and cite the source snippets where appropriate."
        )
        return prompt

    async def _execute_multi_step(self, req: AgentRequest, steps: List[Dict[str, Any]],
                                  decision: AgentDecision, tool_traces: List[Dict[str, Any]],
                                  pre_fetched_rag: Optional[Dict[str, Any]] = None) -> AgentResponse:
        """
        Execute multiple tools in sequence and synthesize results with the LLM.
        """
        rag_data = None
        web_data = None
        admin_data = None
        collected_data = []

        # Execute each step in sequence
        for step_info in steps:
            tool_name = step_info.get("tool")
            step_input = step_info.get("input") or {}
            query = step_input.get("query") or req.message

            try:
                if tool_name == "rag":
                    # Reuse pre-fetched RAG if available, otherwise fetch
                    if pre_fetched_rag:
                        rag_resp = pre_fetched_rag
                        tool_traces.append({"tool": "rag", "response": rag_resp, "note": "used_pre_fetched"})
                    else:
                        rag_resp = await self.mcp.call_rag(req.tenant_id, query)
                        tool_traces.append({"tool": "rag", "response": rag_resp})
                    rag_data = rag_resp
                    # Extract snippets for the prompt
                    if isinstance(rag_resp, dict):
                        hits = rag_resp.get("results") or rag_resp.get("hits") or []
                        for h in hits[:5]:
                            txt = h.get("text") or h.get("content") or str(h)
                            collected_data.append(f"[RAG] {txt}")

                elif tool_name == "web":
                    web_resp = await self.mcp.call_web(req.tenant_id, query)
                    tool_traces.append({"tool": "web", "response": web_resp})
                    web_data = web_resp
                    # Extract snippets for the prompt
                    if isinstance(web_resp, dict):
                        hits = web_resp.get("results") or web_resp.get("items") or []
                        for h in hits[:5]:
                            title = h.get("title") or h.get("headline") or ""
                            snippet = h.get("snippet") or h.get("summary") or h.get("text") or ""
                            url = h.get("url") or h.get("link") or ""
                            collected_data.append(f"[WEB] {title}\n{snippet}\nSource: {url}")

                elif tool_name == "admin":
                    admin_resp = await self.mcp.call_admin(req.tenant_id, query)
                    tool_traces.append({"tool": "admin", "response": admin_resp})
                    admin_data = admin_resp
                    collected_data.append(f"[ADMIN] {json.dumps(admin_resp)}")

                elif tool_name == "llm":
                    # LLM is always last - synthesize all collected data
                    break

            except Exception as e:
                tool_traces.append({"tool": tool_name, "error": str(e)})
                # Continue with other tools even if one fails

        # Build a comprehensive prompt with all collected data
        data_section = "\n---\n".join(collected_data) if collected_data else ""

        if data_section:
            prompt = (
                f"You are an assistant helping tenant {req.tenant_id}.\n\n"
                f"The following information has been gathered from multiple sources:\n\n"
                f"{data_section}\n\n"
                f"User question: {req.message}\n\n"
                f"Provide a comprehensive, accurate answer using the information above. "
                f"Cite sources where appropriate (RAG for internal docs, WEB for online sources)."
            )
        else:
            # No data collected, just answer the question
            prompt = req.message

        # Final LLM synthesis
        try:
            llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
            return AgentResponse(
                text=llm_out,
                decision=decision,
                tool_traces=tool_traces
            )
        except Exception as e:
            tool_traces.append({"tool": "llm", "error": str(e)})
            fallback = f"I encountered an error while synthesizing the response: {str(e)}"
            return AgentResponse(
                text=fallback,
                decision=AgentDecision(
                    action="respond",
                    tool=None,
                    tool_input=None,
                    reason=f"multi_step_llm_error: {e}"
                ),
                tool_traces=tool_traces
            )

    def _build_prompt_with_web(self, req: AgentRequest, web_resp: Dict[str, Any]) -> str:
        snippets = []
        if isinstance(web_resp, dict):
            hits = web_resp.get("results") or web_resp.get("items") or []
            for h in hits[:5]:
                title = h.get("title") or h.get("headline") or ""
                snippet = h.get("snippet") or h.get("summary") or h.get("text") or ""
                url = h.get("url") or h.get("link") or ""
                snippets.append(f"{title}\n{snippet}\nSource: {url}")

        snippet_text = "\n---\n".join(snippets) or ""
        prompt = (
            f"You are an assistant with access to recent web search results. Use the following results to answer.\n{snippet_text}\n\n"
            f"User question: {req.message}\nAnswer succinctly and indicate which results you used."
        )
        return prompt
```
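For reference, this is the shape of the blocked-path response that handle() builds when the detector reports matches. The rule values are hypothetical, and model_dump() assumes Pydantic v2.

```python
from backend.api.models.agent import AgentDecision, AgentResponse

# Hypothetical violation dict, mirroring RedFlagMatch.__dict__
violation = {"rule_id": "r1", "pattern": "", "severity": "high",
             "description": "possible SSN", "matched_text": "ssn"}

blocked = AgentResponse(
    text="Your request has been blocked due to policy.",
    decision=AgentDecision(
        action="block",
        tool="admin",
        tool_input={"violations": [violation]},
        reason="redflag_triggered",
    ),
    tool_traces=[{"redflags": [violation]}],
)
print(blocked.model_dump())
```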

backend/api/services/intent_classifier.py (CHANGED; new version, with the `any` annotation corrected to `typing.Any`):

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class IntentClassifier:
    intent_keywords: Dict[str, List[str]] = field(default_factory=lambda: {
        "rag": ["document", "policy", "manual", "procedure", "hr"],
        "web": ["latest", "today", "news", "current", "price", "stock"],
        "admin": ["delete", "remove", "export", "salary", "confidential"],
        "general": ["explain", "summary", "help"]
    })
    llm_client: Any = None

    async def classify(self, text: str) -> str:
        t = text.lower()
        scores = {k: 0 for k in self.intent_keywords}
        for k, words in self.intent_keywords.items():
            for w in words:
                if w in t:
                    scores[k] += 1
        best = max(scores, key=scores.get)
        if scores[best] > 0:
            return best

        # LLM fallback with error handling
        if self.llm_client:
            try:
                prompt = f"Classify into rag/web/admin/general. User: '{text}'"
                out = (await self.llm_client.simple_call(prompt)).strip().lower()
                return out if out in scores else "general"
            except Exception:
                # LLM failed (not configured or unavailable); default to general
                return "general"
        return "general"
```
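A small demonstration of the keyword path and the default, assuming no LLM client is wired in (so the fallback branch is skipped); the sample messages are mine.

```python
import asyncio

from backend.api.services.intent_classifier import IntentClassifier

clf = IntentClassifier()  # llm_client=None: pure keyword scoring


async def demo():
    print(await clf.classify("Where is the HR policy manual?"))  # "rag" (policy, manual, hr)
    print(await clf.classify("latest stock price today"))        # "web" (latest, today, price, stock)
    print(await clf.classify("sing me a song"))                  # "general" (no keyword hits)


asyncio.run(demo())
```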

backend/api/services/llm_client.py (CHANGED; new version, unused json import dropped):

```python
import os

import httpx


class LLMClient:

    def __init__(self, backend="ollama", url=None, api_key=None, model=None):
        self.backend = backend
        self.url = url or os.getenv("OLLAMA_URL", "http://localhost:11434")
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        self.model = model or os.getenv("OLLAMA_MODEL", "llama3.1:latest")
        self.http = httpx.AsyncClient(timeout=30)

    async def simple_call(self, prompt: str, temperature: float = 0.0) -> str:
        if self.backend == "ollama":
            if not self.url or not self.model:
                raise RuntimeError(f"LLM not configured: url={self.url}, model={self.model}. Set OLLAMA_URL and OLLAMA_MODEL env vars.")

            try:
                # Ollama uses the /api/generate endpoint
                r = await self.http.post(
                    f"{self.url}/api/generate",
                    json={
                        "model": self.model,
                        "prompt": prompt,
                        "stream": False,
                        "options": {"temperature": temperature}
                    }
                )
                r.raise_for_status()
                response_data = r.json()
                return response_data.get("response", "")
            except httpx.HTTPStatusError as e:
                if e.response.status_code == 404:
                    raise RuntimeError(
                        f"Ollama endpoint not found. Is Ollama running at {self.url}? "
                        f"Or does the model '{self.model}' exist? "
                        f"Try: ollama pull {self.model}"
                    )
                elif e.response.status_code == 400:
                    error_detail = e.response.json().get("error", "Unknown error")
                    raise RuntimeError(f"Ollama API error: {error_detail}")
                else:
                    raise RuntimeError(f"Ollama API error: HTTP {e.response.status_code} - {e.response.text}")
            except httpx.ConnectError:
                raise RuntimeError(
                    f"Cannot connect to Ollama at {self.url}. "
                    f"Is Ollama running? Start it with: ollama serve"
                )
            except Exception as e:
                raise RuntimeError(f"LLM call failed: {str(e)}")
        raise RuntimeError("Unsupported backend")
```
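A sketch of calling the client directly, assuming a local Ollama daemon with the default model already pulled (ollama serve; ollama pull llama3.1:latest).

```python
import asyncio

from backend.api.services.llm_client import LLMClient

llm = LLMClient(backend="ollama", url="http://localhost:11434", model="llama3.1:latest")


async def demo():
    # Raises RuntimeError with a setup hint if Ollama is unreachable or the model is missing.
    print(await llm.simple_call("Reply with one word: ping", temperature=0.0))


asyncio.run(demo())
```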

backend/api/services/redflag_detector.py (CHANGED; new version, with the loop-local json import hoisted to the top and an unused dataclass import dropped):

```python
# =============================================================
# File: backend/api/services/redflag_detector.py
# =============================================================
"""
Enterprise RedFlagDetector

- Loads per-tenant rules from Supabase REST (or you can swap to Postgres direct)
- Caches rules per tenant with TTL
- Performs regex and keyword matching
- Returns structured match objects with severity and rule metadata
- Sends notifications to Admin MCP or a webhook
"""

import json
import os
import re
import time
from typing import List, Dict, Any, Optional

import httpx

from ..models.redflag import RedFlagRule, RedFlagMatch


class RedFlagDetector:

    def __init__(self, supabase_url: Optional[str] = None, supabase_key: Optional[str] = None, admin_mcp_url: Optional[str] = None, cache_ttl: int = 300):
        self.supabase_url = supabase_url or os.getenv("SUPABASE_URL")
        self.supabase_key = supabase_key or os.getenv("SUPABASE_SERVICE_KEY")
        self.admin_mcp_url = admin_mcp_url or os.getenv("ADMIN_MCP_URL")
        self.cache_ttl = cache_ttl
        self._rules_cache: Dict[str, Dict[str, Any]] = {}  # tenant_id -> {"fetched_at": ts, "rules": [...]}
        self._client = httpx.AsyncClient(timeout=15)

    async def _fetch_rules_from_supabase(self, tenant_id: str) -> List[RedFlagRule]:
        # Expects a table `redflag_rules` with columns: id, tenant_id, pattern, description, severity, source, enabled, keywords (json array)
        if not self.supabase_url or not self.supabase_key:
            return []

        url = self.supabase_url.rstrip("/") + "/rest/v1/redflag_rules"
        headers = {"apikey": self.supabase_key, "Authorization": f"Bearer {self.supabase_key}"}
        params = {"tenant_id": f"eq.{tenant_id}", "select": "*"}

        r = await self._client.get(url, headers=headers, params=params)
        r.raise_for_status()

        rows = r.json()

        rules: List[RedFlagRule] = []

        for row in rows:
            try:
                keywords = row.get("keywords") or []
                if isinstance(keywords, str):
                    # attempt to parse a JSON-encoded string
                    try:
                        keywords = json.loads(keywords)
                    except Exception:
                        keywords = []

                rules.append(
                    RedFlagRule(
                        id=str(row.get("id")),
                        pattern=row.get("pattern") or "",
                        description=row.get("description") or "",
                        severity=row.get("severity") or "medium",
                        source=row.get("source") or "admin",
                        enabled=row.get("enabled", True),
                        keywords=keywords or [],
                    )
                )
            except Exception:
                # skip invalid rows defensively
                continue

        return rules

    async def load_rules(self, tenant_id: str) -> List[RedFlagRule]:
        now = int(time.time())
        entry = self._rules_cache.get(tenant_id)

        if entry and now - entry["fetched_at"] < self.cache_ttl:
            return entry["rules"]

        rules = await self._fetch_rules_from_supabase(tenant_id)
        self._rules_cache[tenant_id] = {"fetched_at": now, "rules": rules}
        return rules

    async def check(self, tenant_id: str, text: str) -> List[RedFlagMatch]:
        """Return structured matches for the given tenant and text."""
        if not text:
            return []

        rules = await self.load_rules(tenant_id)
        matches: List[RedFlagMatch] = []

        text_lower = text.lower()

        for rule in rules:
            if not rule.enabled:
                continue

            matched = False
            matched_text = ""

            # 1) Keyword quick-check (cheap)
            for kw in (rule.keywords or []):
                if kw and kw.lower() in text_lower:
                    matched = True
                    matched_text = kw
                    break

            # 2) Regex check (more precise)
            if not matched and rule.pattern:
                try:
                    pat = re.compile(rule.pattern, re.IGNORECASE)
                    m = pat.search(text)
                    if m:
                        matched = True
                        matched_text = m.group(0)
                except re.error:
                    # invalid regex; skip this rule
                    continue

            if matched:
                matches.append(
                    RedFlagMatch(
                        rule_id=rule.id,
                        pattern=rule.pattern,
                        severity=rule.severity,
                        description=rule.description,
                        matched_text=matched_text,
                    )
                )

        return matches

    async def notify_admin(self, tenant_id: str, violations: List[RedFlagMatch], source_payload: Optional[Dict[str, Any]] = None) -> None:
        """Notify the Admin MCP server (or a webhook) about the matches."""
        payload = {
            "tenant_id": tenant_id,
            "violations": [v.__dict__ for v in violations],
            "source": source_payload or {},
        }

        # 1) POST to Admin MCP /alert if configured
        if self.admin_mcp_url:
            try:
                await self._client.post(self.admin_mcp_url.rstrip("/") + "/alert", json=payload, timeout=10)
            except Exception:
                # swallow exceptions: notifications should not crash orchestration
                pass

        # 2) Optionally send to a Slack/Teams webhook
        webhook = os.getenv("ALERT_WEBHOOK")
        if webhook:
            try:
                await self._client.post(webhook, json={"text": f"Red-flag for tenant {tenant_id}", "details": payload}, timeout=10)
            except Exception:
                pass

    async def close(self):
        await self._client.aclose()
```
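The detector can be exercised without Supabase by seeding its per-tenant cache directly; this leans on the private _rules_cache attribute, so treat it as a test-only sketch with made-up rule values.

```python
import asyncio
import time

from backend.api.models.redflag import RedFlagRule
from backend.api.services.redflag_detector import RedFlagDetector


async def demo():
    det = RedFlagDetector()  # no Supabase env vars: _fetch_rules_from_supabase returns []
    rule = RedFlagRule(
        id="r1",
        pattern=r"\b\d{3}-\d{2}-\d{4}\b",  # SSN-like pattern
        description="possible SSN",
        severity="high",
        source="system",
        keywords=["ssn"],
    )
    # Seed the cache so load_rules() serves our rule instead of querying Supabase.
    det._rules_cache["tenant-a"] = {"fetched_at": int(time.time()), "rules": [rule]}

    for m in await det.check("tenant-a", "my ssn is 123-45-6789"):
        print(m.rule_id, m.severity, repr(m.matched_text))
    await det.close()


asyncio.run(demo())
```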

backend/api/services/tool_selector.py (CHANGED; new version, with the `any` annotation corrected to `typing.Any` and a dead fallback branch simplified):

```python
from dataclasses import dataclass
from typing import Any
import json
import re


@dataclass
class ToolSelector:
    llm_client: Any = None

    async def select(self, intent: str, text: str, ctx):
        msg = text.lower().strip()

        # ---------------------------------
        # 1. Detect ADMIN RULES FIRST
        # ---------------------------------
        if intent == "admin":
            return _multi_step([
                step("admin", {"query": text}),
                step("llm", {"query": text})
            ], "admin safety rule triggered → llm")

        steps = []
        needs_rag = False
        needs_web = False

        # ---------------------------------
        # 2. Check RAG results (pre-fetch)
        # ---------------------------------
        rag_results = ctx.get("rag_results", [])
        rag_has_data = len(rag_results) > 0

        # RAG patterns: internal knowledge, company-specific, documentation
        rag_patterns = [
            r"company", r"internal", r"documentation", r"our ", r"your ",
            r"knowledge base", r"private", r"internal docs", r"corporate"
        ]
        if rag_has_data or any(re.search(p, msg) for p in rag_patterns):
            needs_rag = True
            if rag_has_data:
                steps.append(step("rag", {"query": text}))

        # ---------------------------------
        # 3. Fact lookup / definition → Web
        # ---------------------------------
        fact_patterns = [
            r"what is ", r"who is ", r"where is ",
            r"tell me about ", r"define ", r"explain ",
            r"history of ", r"information about", r"details about"
        ]
        if any(re.search(p, msg) for p in fact_patterns):
            needs_web = True
            steps.append(step("web", {"query": text}))

        # ---------------------------------
        # 4. Freshness heuristic → Web
        # ---------------------------------
        freshness_keywords = ["latest", "today", "news", "current", "recent",
                              "now", "updates", "breaking", "trending"]
        if any(k in msg for k in freshness_keywords):
            needs_web = True
            # Avoid duplicate web steps
            if not any(s["tool"] == "web" for s in steps):
                steps.append(step("web", {"query": text}))

        # ---------------------------------
        # 5. Complex queries that need multiple sources
        # ---------------------------------
        complex_patterns = [
            r"compare", r"difference between", r"versus", r"vs",
            r"both", r"and also", r"as well as", r"in addition"
        ]
        needs_multiple = any(re.search(p, msg) for p in complex_patterns)

        # ---------------------------------
        # 6. Use the LLM to enhance the plan if we have partial steps or a complex query
        # ---------------------------------
        if self.llm_client and (needs_multiple or (needs_rag and needs_web) or len(steps) == 0):
            plan_prompt = f"""
You are an enterprise MCP agent.
You can select MULTIPLE tools in sequence to provide comprehensive answers.

TOOLS:
- rag → private knowledge retrieval (use for internal/company docs)
- web → online factual lookup (use for public facts, current info)
- llm → final reasoning and synthesis (always include at end)

Current context:
- RAG available: {rag_has_data}
- User message: "{text}"

Determine which tools are needed. You can select:
- Just LLM (simple questions)
- RAG + LLM (internal knowledge questions)
- Web + LLM (public fact questions)
- RAG + Web + LLM (comprehensive questions needing both sources)

Return a JSON list describing the steps, e.g.:

[
  {{"tool": "rag", "reason": "Need internal documentation"}},
  {{"tool": "web", "reason": "Need current public information"}},
  {{"tool": "llm", "reason": "Synthesize all information"}}
]

Only return the JSON array. Do not include markdown formatting.
"""
            try:
                out = await self.llm_client.simple_call(plan_prompt)
                # Clean the output in case the LLM adds markdown fences anyway
                out = out.strip()
                if out.startswith("```json"):
                    out = out[7:]
                if out.startswith("```"):
                    out = out[3:]
                if out.endswith("```"):
                    out = out[:-3]
                out = out.strip()

                steps_json = json.loads(out)

                # Replace steps with LLM-planned steps (excluding llm; it is appended at the end)
                steps = [
                    step(s["tool"], {"query": text})
                    for s in steps_json if s.get("tool") != "llm"
                ]
            except Exception:
                # If LLM planning fails, keep the existing heuristic steps
                pass

        # ---------------------------------
        # 7. Always end with LLM synthesis
        # ---------------------------------
        if not steps or steps[-1]["tool"] != "llm":
            steps.append(step("llm", {
                "rag_data": rag_results if rag_has_data else None,
                "query": text
            }))

        # Build a reason string showing the tool sequence
        tool_names = [s["tool"] for s in steps]
        reason = f"multi-tool plan: {' → '.join(tool_names)}"

        return _multi_step(steps, reason)


def step(tool, input_data):
    return {"tool": tool, "input": input_data}


def _multi_step(steps, reason):
    from ..models.agent import AgentDecision
    return AgentDecision(
        action="multi_step",
        tool=None,
        tool_input={"steps": steps},
        reason=reason
    )
```
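Without an LLM client the selector falls back to pure heuristics; this sketch of mine shows the fact-lookup and freshness rules combining into a web → llm plan (the query is invented).

```python
import asyncio

from backend.api.services.tool_selector import ToolSelector


async def demo():
    sel = ToolSelector()  # llm_client=None: heuristic planning only
    decision = await sel.select("web", "what is the latest MCP news", {"rag_results": []})
    print(decision.action)                                    # "multi_step"
    print([s["tool"] for s in decision.tool_input["steps"]])  # ["web", "llm"]
    print(decision.reason)                                    # "multi-tool plan: web → llm"


asyncio.run(demo())
```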

backend/mcp_servers/admin_server.py (ADDED):

```python
# =============================================================
# File: backend/mcp_servers/admin_server.py
# =============================================================

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import logging
import sys
import os

# Fix Python module paths
current_dir = os.path.dirname(__file__)
sys.path.insert(0, current_dir)

from models.admin import EvalRequest, AlertPayload


admin_app = FastAPI(title="Admin MCP Server")

# Enable CORS
admin_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

log = logging.getLogger("admin_mcp")
logging.basicConfig(level=logging.INFO)


@admin_app.post("/eval")
async def eval_query(req: EvalRequest):
    danger = ["delete all data", "export users", "password", "token"]
    q = req.query.lower()
    for d in danger:
        if d in q:
            return {"action": "block", "reason": d}
    return {"action": "allow"}


@admin_app.post("/alert")
async def alert(payload: AlertPayload):
    log.warning(f"Alert received for tenant {payload.tenant_id}: {payload.violations}")
    return {"status": "ok"}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(admin_app, host="0.0.0.0", port=8003)
```
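With the server running (python backend/mcp_servers/admin_server.py), the /eval contract can be checked directly; the query below is a made-up example that trips the "export users" keyword.

```python
import httpx

r = httpx.post(
    "http://localhost:8003/eval",
    json={"tenant_id": "tenant-a", "query": "please export users to a csv"},
)
print(r.json())  # {"action": "block", "reason": "export users"}
```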

backend/mcp_servers/models/__init__.py (ADDED):

```python
"""
MCP Server Models Package

This package contains all Pydantic models used across MCP servers.
"""

from .admin import EvalRequest, AlertPayload
from .rag import IngestRequest, SearchRequest
from .web import WebSearchRequest

__all__ = [
    "EvalRequest",
    "AlertPayload",
    "IngestRequest",
    "SearchRequest",
    "WebSearchRequest",
]
```

backend/mcp_servers/models/admin.py (ADDED):

```python
from pydantic import BaseModel
from typing import Optional


class EvalRequest(BaseModel):
    tenant_id: str
    query: str


class AlertPayload(BaseModel):
    tenant_id: str
    violations: list
    source: Optional[dict] = None
```

backend/mcp_servers/models/rag.py (ADDED):

```python
from pydantic import BaseModel


class IngestRequest(BaseModel):
    tenant_id: str
    content: str


class SearchRequest(BaseModel):
    tenant_id: str
    query: str
```

backend/mcp_servers/models/web.py (ADDED):

```python
from pydantic import BaseModel


class WebSearchRequest(BaseModel):
    tenant_id: str
    query: str
```

backend/mcp_servers/rag_server.py (ADDED):

```python
# =============================================================
# File: backend/mcp_servers/rag_server.py
# =============================================================

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import sys
import os

# Fix Python module paths
current_dir = os.path.dirname(__file__)
sys.path.insert(0, current_dir)

from embeddings import embed_text
from database import insert_document_chunks, search_vectors
from models.rag import IngestRequest, SearchRequest


rag_app = FastAPI(title="RAG MCP Server")

# Enable CORS
rag_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Wrapper functions to match the expected interface
def db_insert(tenant_id: str, content: str, vector: list):
    """Wrapper for insert_document_chunks to match the expected interface."""
    return insert_document_chunks(tenant_id, content, vector)


def db_search(tenant_id: str, vector: list, limit: int = 5):
    """Wrapper for search_vectors to match the expected interface."""
    results = search_vectors(tenant_id, vector, limit)
    return [{"text": text} for text in results]


@rag_app.post("/ingest")
async def ingest(req: IngestRequest):
    vector = embed_text(req.content)
    db_insert(req.tenant_id, req.content, vector)
    return {"status": "ok"}


@rag_app.post("/search")
async def search(req: SearchRequest):
    vector = embed_text(req.query)
    results = db_search(req.tenant_id, vector)
    return {"results": results}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(rag_app, host="0.0.0.0", port=8001)
```
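Assuming the embeddings and database modules it imports are configured, a round-trip through the two endpoints looks like this; the document text and tenant id are invented.

```python
import httpx

base = "http://localhost:8001"

# Ingest a chunk for one tenant, then search it back.
httpx.post(f"{base}/ingest", json={"tenant_id": "tenant-a",
                                   "content": "Refunds are processed within 14 days."})
hits = httpx.post(f"{base}/search", json={"tenant_id": "tenant-a",
                                          "query": "refund window"}).json()
print(hits["results"])  # [{"text": ...}, ...] scoped to tenant-a
```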
backend/mcp_servers/web_server.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# =============================================================
# File: backend/mcp_servers/web_server.py
# =============================================================

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from duckduckgo_search import DDGS
import sys
import os

# Fix Python module paths
current_dir = os.path.dirname(__file__)
sys.path.insert(0, current_dir)

from models.web import WebSearchRequest

web_app = FastAPI(title="Web Search MCP Server")

# Enable CORS
web_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@web_app.post("/search")
async def web_search(req: WebSearchRequest):
    """
    Web search endpoint using DuckDuckGo.
    Accepts tenant_id for multi-tenant support (currently unused, kept for API consistency).
    Prefers English-language results via the region parameter and a query modifier.
    """
    try:
        ddg = DDGS()

        # Modify the query to prefer English results, adding the
        # language hint only if one is not already present
        query = req.query
        if "lang:en" not in query.lower() and "site:en" not in query.lower():
            query = f"{query} lang:en"

        # Try the region parameter for English results
        # (common region codes: 'us-en' for US English, 'uk-en' for UK English)
        try:
            results = ddg.text(query, max_results=5, region="us-en")
        except (TypeError, KeyError):
            # If the region parameter is not supported, retry without it;
            # the lang:en hint in the query should still help
            results = ddg.text(query, max_results=5)

        formatted = []
        for r in results:
            formatted.append({
                "title": r.get("title"),
                "snippet": r.get("body"),
                "url": r.get("href"),
            })

        return {"results": formatted}

    except Exception as e:
        return {"error": str(e), "results": []}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(web_app, host="0.0.0.0", port=8002)
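The same kind of smoke test works here; a minimal sketch assuming the server is running on port 8002 (tenant ID and query are illustrative):

import httpx

resp = httpx.post(
    "http://localhost:8002/search",
    json={"tenant_id": "tenant1", "query": "latest FastAPI release"},
    timeout=30.0,
)
for hit in resp.json().get("results", []):
    print(hit["title"], "->", hit["url"])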
backend/tests/conftest.py
ADDED
@@ -0,0 +1 @@
(empty file)
backend/tests/test_agent_orchestrator.py
CHANGED
@@ -1,3 +1,7 @@
# =============================================================
# File: tests/test_agent_orchestrator.py
# =============================================================

import sys
from pathlib import Path

@@ -5,17 +9,212 @@ from pathlib import Path
backend_dir = Path(__file__).parent.parent
sys.path.insert(0, str(backend_dir))

try:
    import pytest
    HAS_PYTEST = True
except ImportError:
    HAS_PYTEST = False
    # Create a mock pytest decorator if pytest is not available
    class MockMark:
        def asyncio(self, func):
            return func
    class MockPytest:
        mark = MockMark()
        def fixture(self, func):
            return func
    pytest = MockPytest()

import os
from api.services.agent_orchestrator import AgentOrchestrator
from api.models.agent import AgentRequest, AgentDecision, AgentResponse
from api.models.redflag import RedFlagMatch
from api.services.llm_client import LLMClient


# ---------------------------
# Mock classes
# ---------------------------

class FakeLLM(LLMClient):
    def __init__(self, output="LLM_RESPONSE"):
        self.output = output

    async def simple_call(self, prompt: str, temperature: float = 0.0):
        return self.output


class FakeMCP:
    """Fake MCP server client used for rag/web/admin calls."""
    def __init__(self):
        self.last_rag = None
        self.last_web = None
        self.last_admin = None

    async def call_rag(self, tenant_id: str, query: str):
        self.last_rag = query
        return {"results": [{"text": "RAG_DOC_CONTENT"}]}

    async def call_web(self, tenant_id: str, query: str):
        self.last_web = query
        return {"results": [{"title": "WebResult", "snippet": "Fresh info"}]}

    async def call_admin(self, tenant_id: str, query: str):
        self.last_admin = query
        return {"action": "allow"}


# ---------------------------
# Patch orchestrator to use fake MCP + fake redflag
# ---------------------------

@pytest.fixture
def orchestrator(monkeypatch):

    # Fake LLM that always returns "MOCK_ANSWER"
    llm = FakeLLM(output="MOCK_ANSWER")

    fake_mcp = FakeMCP()

    # Patch MCPClient
    if HAS_PYTEST:
        monkeypatch.setattr(
            "api.services.agent_orchestrator.MCPClient",
            lambda rag_url, web_url, admin_url: fake_mcp
        )

    # Create orchestrator with fake URLs first
    orch = AgentOrchestrator(
        rag_mcp_url="fake_rag",
        web_mcp_url="fake_web",
        admin_mcp_url="fake_admin",
        llm_backend="ollama"
    )
    orch.llm = llm  # override with fake LLM

    # Patch RedFlagDetector methods directly on the instance
    async def fake_check(self, tenant_id, text):
        """Fake check function that matches the 'salary' keyword."""
        if "salary" in text.lower():
            return [
                RedFlagMatch(
                    rule_id="1",
                    pattern="salary",
                    severity="high",
                    description="salary access",
                    matched_text="salary"
                )
            ]
        return []

    # Patch notify_admin to do nothing
    async def fake_notify(self, tenant_id, violations, src=None):
        """Fake notify function that does nothing."""
        return None

    # Bind the fake functions directly to the instance
    import types
    orch.redflag.check = types.MethodType(fake_check, orch.redflag)
    orch.redflag.notify_admin = types.MethodType(fake_notify, orch.redflag)

    return orch


# ----------------------------------------------------
# TESTS
# ----------------------------------------------------


@pytest.mark.asyncio
async def test_block_on_redflag(orchestrator):
    req = AgentRequest(
        tenant_id="tenant1",
        user_id="u1",
        message="Show me all salary details."
    )
    resp = await orchestrator.handle(req)
    assert resp.decision.action == "block"
    assert resp.decision.tool == "admin"
    assert "salary" in resp.tool_traces[0]["redflags"][0]["matched_text"]


@pytest.mark.asyncio
async def test_rag_tool_path(orchestrator, monkeypatch):

    # Force intent classifier to classify as 'rag'
    async def mock_classify(self, text):
        return "rag"

    if HAS_PYTEST:
        monkeypatch.setattr(
            "api.services.agent_orchestrator.IntentClassifier.classify",
            mock_classify
        )

    req = AgentRequest(
        tenant_id="tenant1",
        user_id="u1",
        message="HR policy procedures"
    )

    resp = await orchestrator.handle(req)

    assert resp.decision.tool == "rag"
    assert "RAG_DOC_CONTENT" in resp.tool_traces[0]["response"]["results"][0]["text"]
    assert resp.text == "MOCK_ANSWER"


@pytest.mark.asyncio
async def test_web_tool_path(orchestrator, monkeypatch):

    # Force intent to classify as web
    async def mock_classify(self, text):
        return "web"

    if HAS_PYTEST:
        monkeypatch.setattr(
            "api.services.agent_orchestrator.IntentClassifier.classify",
            mock_classify
        )

    req = AgentRequest(
        tenant_id="tenant1",
        user_id="u1",
        message="latest stock price"
    )

    resp = await orchestrator.handle(req)

    assert resp.decision.tool == "web"
    assert resp.text == "MOCK_ANSWER"


@pytest.mark.asyncio
async def test_default_llm_path(orchestrator, monkeypatch):

    # Force intent = general and force tool selector to NOT call any tool
    async def mock_select(self, intent, text, context):
        from api.models.agent import AgentDecision
        return AgentDecision(
            action="respond",
            tool=None,
            tool_input=None,
            reason="forced_llm"
        )

    if HAS_PYTEST:
        monkeypatch.setattr(
            "api.services.agent_orchestrator.ToolSelector.select",
            mock_select
        )

    req = AgentRequest(
        tenant_id="tenant1",
        user_id="u1",
        message="just a normal question"
    )

    resp = await orchestrator.handle(req)

    assert resp.decision.action == "respond"
    assert resp.decision.tool is None
    assert resp.text == "MOCK_ANSWER"
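One pattern worth noting in the fixture above: instead of monkeypatching the RedFlagDetector class, the fakes are bound to a single instance with types.MethodType, so any other instances stay untouched. A standalone sketch of the same trick (all names here are illustrative, not from the codebase):

import asyncio
import types

class Detector:
    async def check(self, tenant_id, text):
        return ["real match"]

async def fake_check(self, tenant_id, text):
    # stand-in that pretends nothing matched
    return []

d = Detector()
d.check = types.MethodType(fake_check, d)  # patches only this instance
print(asyncio.run(d.check("t1", "hello")))  # prints []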
backend/tests/test_intent.py
CHANGED
@@ -1,3 +1,7 @@
# =============================================================
# File: tests/test_intent.py
# =============================================================

import sys
from pathlib import Path

@@ -5,43 +9,110 @@ from pathlib import Path
backend_dir = Path(__file__).parent.parent
sys.path.insert(0, str(backend_dir))

try:
    import pytest
    HAS_PYTEST = True
except ImportError:
    HAS_PYTEST = False
    # Create a mock pytest decorator if pytest is not available
    class MockMark:
        def asyncio(self, func):
            return func
    class MockPytest:
        mark = MockMark()
    pytest = MockPytest()

import asyncio
from api.services.intent_classifier import IntentClassifier
from api.services.llm_client import LLMClient
from api.services.redflag_detector import RedFlagDetector
from api.services.tool_selector import ToolSelector
from api.models.redflag import RedFlagMatch


@pytest.mark.asyncio
async def test_intent_rag_keywords():
    classifier = IntentClassifier()
    intent = await classifier.classify("Please check the HR policy document")
    assert intent == "rag"


@pytest.mark.asyncio
async def test_intent_web_keywords():
    classifier = IntentClassifier()
    intent = await classifier.classify("latest news about Tesla stock")
    assert intent == "web"


@pytest.mark.asyncio
async def test_intent_admin_keywords():
    classifier = IntentClassifier()
    intent = await classifier.classify("export all user data")
    assert intent == "admin"


@pytest.mark.asyncio
async def test_intent_general():
    classifier = IntentClassifier()
    intent = await classifier.classify("explain how gravity works")
    assert intent == "general"


# ---- LLM fallback test ----

class FakeLLM:
    async def simple_call(self, prompt: str, temperature: float = 0.0):
        return "web"


@pytest.mark.asyncio
async def test_intent_llm_fallback():
    classifier = IntentClassifier(llm_client=FakeLLM())
    intent = await classifier.classify("What's going on in the world?")
    assert intent == "web"


# ---- Manual run function (for non-pytest execution) ----

async def run_manual_tests():
    llm = LLMClient()
    clf = IntentClassifier(llm_client=llm)

    # Initialize detector with empty creds (returns empty results if no Supabase)
    import os
    detector = RedFlagDetector(
        supabase_url=os.getenv("SUPABASE_URL") or "",
        supabase_key=os.getenv("SUPABASE_SERVICE_KEY") or ""
    )
    selector = ToolSelector(llm_client=llm)

    print("Intent Classification:")
    print("RAG:", await clf.classify("summarize internal policy"))
    print("WEB:", await clf.classify("latest news about ai"))
    print("ADMIN:", await clf.classify("delete all data"))
    print("GENERAL:", await clf.classify("hi how are you"))

    print("\nRedFlag checks (will be empty if no Supabase configured):")
    try:
        print(await detector.check("tenant123", "My email is [email protected]"))
        print(await detector.check("tenant123", "delete all data now"))
        print(await detector.check("tenant123", "confidential salary report"))
        print(await detector.check("tenant123", "hello world"))
    except Exception as e:
        print(f"RedFlag check failed (expected if Supabase not configured): {e}")

    print("\nTool selection:")
    print(await selector.select("admin", "delete all data", {}))
    print(await selector.select("rag", "summarize policy", {}))
    print(await selector.select("web", "latest news", {}))
    print(await selector.select("general", "hello", {}))

    print("\nLLM Test:")
    try:
        if llm.url and llm.model:
            result = await llm.simple_call("Hello Llama!")
            print(f"LLM Result: {result}")
        else:
            print("LLM not configured (OLLAMA_URL/OLLAMA_MODEL not set) - skipping LLM test")
    except Exception as e:
        print(f"LLM call failed (expected if Ollama not running or not configured): {e}")


if __name__ == "__main__":
    asyncio.run(run_manual_tests())
env.example
ADDED
@@ -0,0 +1,41 @@
# =============================================================
# IntegraChat Environment Variables Template
# =============================================================
# Copy this file to .env and fill in your actual values

# =============================================================
# SUPABASE CONFIGURATION
# =============================================================
SUPABASE_URL=https://your-project.supabase.co
SUPABASE_SERVICE_KEY=your_service_role_key_here
POSTGRESQL_URL=postgresql://user:password@host:port/database

# =============================================================
# LLM CONFIGURATION
# =============================================================
# If using local Ollama
OLLAMA_URL=http://localhost:11434
OLLAMA_MODEL=llama3.1:latest

# Backend selection (optional, defaults to "ollama")
LLM_BACKEND=ollama

# =============================================================
# MCP SERVER URLs
# =============================================================
RAG_MCP_URL=http://localhost:8001
WEB_MCP_URL=http://localhost:8002
ADMIN_MCP_URL=http://localhost:8003

# =============================================================
# BACKEND CONFIG
# =============================================================
APP_ENV=development
LOG_LEVEL=info
API_PORT=8000

# =============================================================
# OPTIONAL: ALERTING
# =============================================================
# ALERT_WEBHOOK=https://hooks.slack.com/services/your/webhook/url
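These values are meant to be read at startup via python-dotenv (already in requirements.txt). A minimal loading sketch using the variable names from the template above (the fallback defaults are illustrative):

import os
from dotenv import load_dotenv

load_dotenv()  # loads .env from the working directory, if present

RAG_MCP_URL = os.getenv("RAG_MCP_URL", "http://localhost:8001")
WEB_MCP_URL = os.getenv("WEB_MCP_URL", "http://localhost:8002")
ADMIN_MCP_URL = os.getenv("ADMIN_MCP_URL", "http://localhost:8003")
OLLAMA_URL = os.getenv("OLLAMA_URL")      # None if Ollama is not configured
OLLAMA_MODEL = os.getenv("OLLAMA_MODEL")  # e.g. "llama3.1:latest"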
pytest.ini
ADDED
@@ -0,0 +1,7 @@
[pytest]
asyncio_mode = auto
testpaths = backend/tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
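With pytest-asyncio in auto mode, every async def test_* under backend/tests runs on an event loop even without an explicit @pytest.mark.asyncio, so the markers in the test files above are harmless but redundant. A minimal illustration:

import asyncio

async def test_runs_on_event_loop():
    # collected and awaited automatically because asyncio_mode = auto
    await asyncio.sleep(0)
    assert True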
requirements.txt
CHANGED
@@ -5,4 +5,7 @@ httpx
python-dotenv
psycopg2
supabase
sentence-transformers
pytest
pytest-asyncio
duckduckgo-search