# Source: sanatan_ai / server.py (Hugging Face repository by vikramvasudevan)
# Commit 8d5ed9c (verified), "Upload folder using huggingface_hub"; 17.7 kB.
# server.py
import asyncio
import json
import logging
import random
import traceback
import uuid
from typing import List, Optional

import pycountry
from fastapi import APIRouter, Query, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from chat_utils import chat
from config import SanatanConfig
from db import SanatanDatabase
from metadata import MetadataWhereClause
from modules.audio.model import AudioRequest, AudioType
from modules.audio.service import svc_get_audio_urls, svc_get_indices_with_audio
from modules.config.categories import get_scripture_categories
from modules.quiz.answer_validator import validate_answer
from modules.quiz.models import Question
from modules.quiz.quiz_helper import generate_question
from modules.video.model import VideoRequest
from modules.video.service import svc_get_video_urls
# Install a default root handler if none is configured yet, and let this
# module's logger emit INFO-level messages.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Every endpoint in this module is registered on this router; it is presumably
# included into the FastAPI app elsewhere (not visible here).
router = APIRouter()
# In-memory mapping from session_id -> thread_id, filled in by /chat.
# Nothing in this module removes entries, so it grows with every new session
# for the life of the process, and it is not shared between worker processes.
# For production, you may want Redis or a DB for persistence
thread_map: dict[str, str] = {}
class Message(BaseModel):
    # Request body shared by POST /greet and POST /chat.
    language: str  # preferred reply language (/chat falls back to "English" if empty)
    text: str  # the user's message
    session_id: Optional[str] = None  # client session ID; a UUID4 is minted when absent
class QuizGeneratePayload(BaseModel):
    # Body of POST /quiz/generate. Every field may be omitted: the handler picks
    # a random value for any empty scripture / complexity / mode.
    language: str | None = "English"
    scripture: str | None = None  # collection name to draw the question from
    complexity: str | None = None  # e.g. "beginner" | "intermediate" | "advanced"
    mode: str | None = None  # e.g. "mcq" | "open"
    session_id: str | None = None  # Optional session ID from client
class QuizEvalPayload(BaseModel):
    # Body of POST /quiz/eval: the question that was asked plus the user's answer.
    language: str | None = "English"
    q: Question
    answer: str
    session_id: str | None = None  # Optional session ID from client
# Each supported language's name written in that language, keyed by ISO 639
# code. Two-letter codes are ISO 639-1; three-letter ones ("mai", "sat") are
# resolved as alpha-3 codes by the /languages handler. Insertion order is the
# order the languages are listed in.
LANG_NATIVE_NAMES = dict(
    (
        ("en", "English"),
        ("fr", "Français"),
        ("es", "Español"),
        ("hi", "हिन्दी"),
        ("bn", "বাংলা"),
        ("te", "తెలుగు"),
        ("mr", "मराठी"),
        ("ta", "தமிழ்"),
        ("ur", "اردو"),
        ("gu", "ગુજરાતી"),
        ("kn", "ಕನ್ನಡ"),
        ("ml", "മലയാളം"),
        ("pa", "ਪੰਜਾਬੀ"),
        ("as", "অসমীয়া"),
        ("mai", "मैथिली"),
        ("sd", "سنڌي"),
        ("sat", "ᱥᱟᱱᱛᱟᱲᱤ"),
    )
)
@router.get("/languages")
async def handle_fetch_languages():
    """Return the supported UI languages, sorted by English name.

    Each entry is ``{"code", "name", "native_name"}``:

    - ``name`` is the English name reported by pycountry.
    - ``native_name`` comes from ``LANG_NATIVE_NAMES``.

    Each code is tried as an ISO 639 alpha-2 code first, then as alpha-3, so
    three-letter codes such as "mai" also resolve. Codes pycountry cannot
    resolve are skipped.
    """
    languages = []
    # LANG_NATIVE_NAMES is the single source of truth for the supported codes,
    # so adding a language there is enough to expose it here (previously a
    # separate hard-coded list had to be kept in sync by hand).
    for code, native_name in LANG_NATIVE_NAMES.items():
        lang = pycountry.languages.get(alpha_2=code) or pycountry.languages.get(
            alpha_3=code
        )
        if lang is None:
            continue  # skip codes pycountry does not know
        languages.append(
            {
                "code": code,
                "name": lang.name,
                "native_name": native_name,
            }
        )
    languages.sort(key=lambda entry: entry["name"])
    return languages
@router.post("/greet")
async def handle_greet(msg: Message):
    """Return a markdown welcome message listing every configured scripture.

    Scriptures are sorted by title. Each line shows the title and the number
    of units stored in that scripture's collection. The client's
    ``session_id`` is echoed back; a new UUID4 is generated when it is missing
    or empty.
    """
    scriptures = sorted(SanatanConfig().scriptures, key=lambda doc: doc["title"])
    db = SanatanDatabase()  # one handle shared by all the per-collection counts
    lines = [
        "Namaskaram 🙏 I am **bhashyam.ai** and I can help you explore the following scriptures:\n---\n"
    ]
    for scripture in scriptures:
        num_units = db.count(collection_name=scripture["collection_name"])
        # Single quotes for the inner subscripts: reusing the outer quote type
        # inside an f-string is a SyntaxError before Python 3.12 (PEP 701).
        lines.append(
            f"- {scripture['title']} : `{num_units}` {scripture['unit']}s\n"
        )
    session_id = msg.session_id or str(uuid.uuid4())
    return {"reply": "".join(lines), "session_id": session_id}
@router.post("/chat")
async def handle_chat(msg: Message, request: Request):
    """Answer one chat message, keeping conversation state per session.

    The client's ``session_id`` (or a freshly generated UUID4) is mapped to a
    ``thread_id`` in ``thread_map``. That ``thread_id`` is reused for the life
    of this process and passed to ``chat`` so follow-up messages share one
    thread.

    Returns ``{"reply", "session_id"}``. On failure it responds with HTTP 500
    and ``{"reply": "Error: ...", "session_id": ...}``, so the client can keep
    the session it was assigned.
    """
    # Use existing session_id if provided, else generate new. Resolved before
    # the try so the error response can return it too.
    session_id = msg.session_id or str(uuid.uuid4())
    try:
        logger.info("%s : user sent message : %s", session_id, msg.text)
        # Get or create the thread_id for this session.
        if session_id not in thread_map:
            thread_map[session_id] = str(uuid.uuid4())
        thread_id = thread_map[session_id]
        # chat() is synchronous; run it in a worker thread so a slow reply
        # does not block the event loop (and with it every other request).
        reply_text = await asyncio.to_thread(
            chat,
            debug_mode=False,
            message=msg.text,
            history=None,
            thread_id=thread_id,
            preferred_language=msg.language or "English",
        )
        return {"reply": reply_text, "session_id": session_id}
    except Exception as e:  # top-level request boundary: log and report
        logger.exception("handle_chat: failed for session %s", session_id)
        return JSONResponse(
            status_code=500,
            content={"reply": f"Error: {e}", "session_id": session_id},
        )
@router.post("/quiz/generate")
async def handle_quiz_generate(payload: QuizGeneratePayload, request: Request):
    """Generate one quiz question and return it as a plain dict.

    If ``scripture``, ``complexity`` or ``mode`` is empty in the payload, a
    random value is chosen for it. The ``yt_metadata`` collection is never
    picked.
    """
    # Defaults are chosen in the same order as before: collection, complexity, mode.
    collection = payload.scripture or random.choice(
        [
            s["collection_name"]
            for s in SanatanConfig().scriptures
            if s["collection_name"] != "yt_metadata"
        ]
    )
    complexity = payload.complexity or random.choice(
        ["beginner", "intermediate", "advanced"]
    )
    mode = payload.mode or random.choice(["mcq", "open"])
    # generate_question() is synchronous; keep it off the event loop.
    # NOTE(review): "preferred_lamguage" (sic) is left as-is because it must
    # match the keyword generate_question() declares, which is not visible
    # here; confirm and fix the spelling on both sides together.
    q = await asyncio.to_thread(
        generate_question,
        collection=collection,
        complexity=complexity,
        mode=mode,
        preferred_lamguage=payload.language or "English",
    )
    logger.info("handle_quiz_generate: %s", q.model_dump_json(indent=1))
    return q.model_dump()
@router.post("/quiz/eval")
async def handle_quiz_eval(payload: QuizEvalPayload, request: Request):
    """Grade the user's answer to a previously generated quiz question.

    Returns the object produced by ``validate_answer``, which FastAPI
    serializes. It exposes ``model_dump_json``, i.e. it is a pydantic model.
    """
    # validate_answer() is synchronous; keep it off the event loop.
    result = await asyncio.to_thread(
        validate_answer,
        payload.q,
        payload.answer,
        preferred_language=payload.language or "English",
    )
    logger.info("handle_quiz_eval: %s", result.model_dump_json(indent=1))
    return result
@router.get("/scriptures")
async def handle_get_scriptures():
    """Map each scripture's collection name to its display title.

    The ``yt_metadata`` collection is left out.
    """
    return {
        entry["collection_name"]: entry["title"]
        for entry in SanatanConfig().scriptures
        if entry["collection_name"] != "yt_metadata"
    }
class ScriptureRequest(BaseModel):
    # Body of POST /scripture.
    scripture_name: str  # matched against the config entry's "name"
    unit_index: int  # global index of the unit to fetch
@router.post("/scripture")
async def get_scripture(req: ScriptureRequest):
    """
    Return a scripture unit (page or verse, based on config),
    including all metadata fields separately.
    Used by the page view to fetch a unit by its global index.

    Responds with ``{"error": ...}`` (regular HTTP 200) in two cases: the
    scripture name is unknown, or the unit cannot be fetched. Otherwise it
    returns the canonicalized document plus ``total``, the number of units in
    the collection, so the client can paginate.
    """
    logger.info("get_scripture: received request to fetch scripture: %s", req)
    sanatan_config = SanatanConfig()
    # find config entry for the scripture
    config = next(
        (s for s in sanatan_config.scriptures if s["name"] == req.scripture_name),
        None,
    )
    if not config:
        return {"error": f"Scripture '{req.scripture_name}' not found"}

    db = SanatanDatabase()  # reused for the fetch and the count below
    raw_doc = db.fetch_document_by_index(
        collection_name=config["collection_name"],
        index=req.unit_index,
    )
    # Treat an empty result, a bare string, or a payload carrying an "error"
    # key as "no data".
    if not raw_doc or isinstance(raw_doc, str) or "error" in raw_doc:
        return {"error": f"No data available for unit {req.unit_index}"}

    canonical_doc = sanatan_config.canonicalize_document(
        scripture_name=req.scripture_name,
        document_text=raw_doc.get("document", ""),
        metadata_doc=raw_doc,
    )
    # add total units (so Flutter can paginate)
    canonical_doc["total"] = db.count(config["collection_name"])
    return canonical_doc
@router.get("/scripture_configs")
async def get_scripture_configs():
    """Describe every configured scripture for the client, sorted by title.

    Callable ``lov`` values inside ``metadata_fields`` are evaluated here. If
    one raises, the failure is logged and that field gets an empty list, so a
    single bad field does not break the whole response. Callables inside
    ``field_mapping`` are removed via ``remove_callables``.
    """
    config = SanatanConfig()
    db = SanatanDatabase()  # one handle for all the per-collection counts
    scriptures = []
    for s in config.scriptures:
        num_units = db.count(collection_name=s["collection_name"])
        # Copy each field dict (shallow copy) so replacing "lov" below does
        # not mutate the shared config.
        metadata_fields = []
        for f in s.get("metadata_fields", []):
            f_copy = dict(f)
            lov = f_copy.get("lov")
            if callable(lov):  # evaluate the function
                try:
                    f_copy["lov"] = lov()
                except Exception:
                    logger.exception(
                        "get_scripture_configs: lov() failed for field %s of %s",
                        f_copy.get("name"),
                        s["name"],
                    )
                    f_copy["lov"] = []
            metadata_fields.append(f_copy)
        scriptures.append(
            {
                "name": s["name"],  # e.g. "bhagavad_gita"
                "title": s["title"],  # e.g. "Bhagavad Gita"
                "category": s["category"],  # e.g. "Philosophy"
                "unit": s["unit"],  # e.g. "verse" or "page"
                "unit_field": s.get("unit_field", s.get("unit")),
                "total": num_units,
                # A scripture counts as enabled once it has a field mapping.
                "enabled": "field_mapping" in s,
                "source": s.get("source", ""),
                "credits": s.get(
                    "credits", {"art": [], "data": [], "audio": [], "video": []}
                ),
                "metadata_fields": metadata_fields,
                "field_mapping": config.remove_callables(s.get("field_mapping", {})),
            }
        )
    return {"scriptures": sorted(scriptures, key=lambda s: s["title"])}
class ScriptureFirstSearchRequst(BaseModel):
    # Body of POST /scripture/{scripture_name}/search. Both filters are optional.
    filter_obj: Optional[MetadataWhereClause] = None  # metadata where-clause
    has_audio: Optional[AudioType] = None  # audio-availability filter
@router.post("/scripture/{scripture_name}/search")
async def search_scripture_find_first_match(
    scripture_name: str,
    req: ScriptureFirstSearchRequst,
):
    """
    Search scripture collection and return the first matching result after applying audio filter.

    "First" means the match with the smallest ``_global_index``. ``has_audio``
    accepts:

    - ``AudioType.none``: keep units that have no audio of any type.
    - ``AudioType.any``: keep units that have at least one audio type.
    - any single audio type: keep units that have that type.

    Returns ``{"results": [doc]}`` (an empty list when nothing matches), or
    ``{"error": ...}`` on failure.
    """
    filter_obj = req.filter_obj
    has_audio = req.has_audio
    try:
        logger.info(
            "search_scripture_find_first_match: searching for %s with filters=%s | has_audio=%s",
            scripture_name,
            filter_obj,
            has_audio,
        )
        db = SanatanDatabase()
        sanatan_config = SanatanConfig()  # built once, not once per result
        config = next(
            (s for s in sanatan_config.scriptures if s["name"] == scripture_name),
            None,
        )
        if not config:
            return {"error": f"Scripture '{scripture_name}' not found"}

        # 1️⃣ Fetch all matches (no pagination) so the audio filter sees them all
        results = db.fetch_all_matches(
            collection_name=config["collection_name"],
            metadata_where_clause=filter_obj,
            page=None,
            page_size=None,
        )
        documents = results.get("documents")
        formatted_results = []
        for i, metadata_doc in enumerate(results["metadatas"]):
            metadata_doc["id"] = results["ids"][i]
            document_text = documents[i] if documents else None
            formatted_results.append(
                sanatan_config.canonicalize_document(
                    scripture_name, document_text, metadata_doc
                )
            )

        # 2️⃣ Apply has_audio filter
        if has_audio and formatted_results:
            all_audio_types = (
                AudioType.recitation,
                AudioType.virutham,
                AudioType.upanyasam,
                AudioType.santhai,
            )
            # "none" and "any" both need the union over every audio type.
            if has_audio in (AudioType.none, AudioType.any):
                audio_types = all_audio_types
            else:
                audio_types = (has_audio,)
            audio_indices = set()
            for atype in audio_types:
                audio_indices.update(
                    await svc_get_indices_with_audio(scripture_name, atype)
                )
            if has_audio == AudioType.none:
                formatted_results = [
                    r for r in formatted_results if r["_global_index"] not in audio_indices
                ]
            else:
                formatted_results = [
                    r for r in formatted_results if r["_global_index"] in audio_indices
                ]

        # 3️⃣ + 4️⃣ Only the lowest global index is needed, so take min() instead
        # of sorting the whole list (same element: min keeps the first of ties,
        # like a stable sort).
        if not formatted_results:
            return {"results": []}
        return {"results": [min(formatted_results, key=lambda x: x["_global_index"])]}
    except Exception as e:
        logger.error("Error while searching %s", e, exc_info=True)
        return {"error": str(e)}
class ScriptureMultiSearchRequest(BaseModel):
    # Body of POST /scripture/{scripture_name}/search/all.
    filter_obj: Optional[MetadataWhereClause] = None
    # 1-based page number and page size, both required to be >= 1. The handler
    # slices from (page - 1) * page_size, so 0 or negative values would silently
    # return an empty page or items from the end of the list; they are now
    # rejected with a validation error instead.
    page: int = Field(1, ge=1)
    page_size: int = Field(20, ge=1)
    has_audio: Optional[AudioType] = None
@router.post("/scripture/{scripture_name}/search/all")
async def search_scripture_find_all_matches(
    scripture_name: str, req: ScriptureMultiSearchRequest
):
    """
    Search scripture collection and return all matching results with pagination.
    - `scripture_name`: scripture config "name" (its collection is looked up from the config)
    - `filter_obj`: MetadataWhereClause (filters, groups, operator)
    - `page`: 1-based page number
    - `page_size`: Number of results per page
    - `has_audio` : optional. can take values any|none|recitation|virutham|upanyasam|santhai

    The audio filter is applied before pagination, so ``total_matches``
    counts the filtered set. Results keep the order returned by the database.
    Returns ``{"error": ...}`` on failure.
    """
    filter_obj = req.filter_obj
    page = req.page
    page_size = req.page_size
    has_audio = req.has_audio
    try:
        logger.info(
            "search_scripture_find_all_matches: searching for %s with filters %s | page=%s, page_size=%s, has_audio=%s",
            scripture_name,
            filter_obj,
            page,
            page_size,
            has_audio,
        )
        db = SanatanDatabase()
        sanatan_config = SanatanConfig()  # built once, not once per result
        config = next(
            (s for s in sanatan_config.scriptures if s["name"] == scripture_name),
            None,
        )
        if not config:
            return {"error": f"Scripture '{scripture_name}' not found"}

        # 1️⃣ Fetch all matching metadata WITHOUT pagination yet; the audio
        # filter must see every match before a page is sliced out.
        results = db.fetch_all_matches(
            collection_name=config["collection_name"],
            metadata_where_clause=filter_obj,
            page=None,
            page_size=None,
        )
        documents = results.get("documents")
        formatted_results = []
        for i, metadata_doc in enumerate(results["metadatas"]):
            metadata_doc["id"] = results["ids"][i]
            document_text = documents[i] if documents else None
            formatted_results.append(
                sanatan_config.canonicalize_document(
                    scripture_name, document_text, metadata_doc
                )
            )

        # 2️⃣ Apply has_audio filter
        if has_audio:
            all_audio_types = (
                AudioType.recitation,
                AudioType.virutham,
                AudioType.upanyasam,
                AudioType.santhai,
            )
            # "none" and "any" both need the union over every audio type.
            if has_audio in (AudioType.none, AudioType.any):
                audio_types = all_audio_types
            else:
                audio_types = (has_audio,)
            audio_indices = set()
            for atype in audio_types:
                audio_indices.update(
                    await svc_get_indices_with_audio(scripture_name, atype)
                )
            if has_audio == AudioType.none:
                # Keep only units with no audio of any type.
                formatted_results = [
                    r for r in formatted_results if r["_global_index"] not in audio_indices
                ]
            else:
                formatted_results = [
                    r for r in formatted_results if r["_global_index"] in audio_indices
                ]

        # 3️⃣ Apply pagination on filtered results
        total_matches = len(formatted_results)
        start_idx = (page - 1) * page_size
        paginated_results = formatted_results[start_idx : start_idx + page_size]
        return {
            "results": paginated_results,
            "total_matches": total_matches,
            "page": page,
            "page_size": page_size,
        }
    except Exception as e:
        logger.error("Error while searching %s", e, exc_info=True)
        return {"error": str(e)}
@router.post("/audio")
async def generate_audio_urls(req: AudioRequest):
    """Return whatever ``svc_get_audio_urls`` resolves for *req*."""
    logger.info("generate_audio_urls: %s", req)
    return await svc_get_audio_urls(req)
@router.post("/video")
async def generate_video_urls(req: VideoRequest):
    """Return whatever ``svc_get_video_urls`` resolves for *req*.

    The name is deliberately different from the /audio handler. Reusing
    ``generate_audio_urls`` rebound that module-level name and labelled this
    endpoint "Generate Audio Urls" in the OpenAPI docs.
    """
    logger.info("generate_video_urls: %s", req)
    video_urls = await svc_get_video_urls(req)
    return video_urls
@router.get("/scripture_categories")
def route_get_scripture_categories():
    """Expose the scripture categories from ``get_scripture_categories``."""
    categories = get_scripture_categories()
    return categories