sanatan_ai / config.py
vikramvasudevan's picture
Upload folder using huggingface_hub
3c0fb3e verified
raw
history blame
5.74 kB
from metadata import MetadataWhereClause
from typing import List, Dict
from modules.config import scripture_configurations
class SanatanConfig:
    """Lookup and normalisation helpers over the configured scriptures.

    ``scriptures`` is the ``scripture_configurations`` object imported from
    ``modules.config``. The methods below treat it as an iterable of dicts
    and read the keys ``name``, ``collection_name``, ``metadata_fields``,
    ``collection_embedding_fn``, ``field_mapping``, ``unit_field``, ``unit``,
    ``title``, ``source`` and ``language`` from each entry.
    """

    # Location of the ChromaDB store (not read by any method of this class;
    # presumably consumed by the code that opens the database).
    dbStorePath: str = "./chromadb-store"
    scriptures = scripture_configurations

    # Canonical keys that a scripture's "field_mapping" is allowed to populate
    # in canonicalize_document(). Built once instead of on every call.
    _CANONICAL_KEYS = frozenset(
        {
            "_global_index",
            "id",
            "verse",
            "text",
            "title",
            "unit",
            "unit_index",
            "word_by_word_native",
            "translation",
            "transliteration",
            "reference_link",
            "author",
            "chapter_name",
            "relative_path",
            "location",
        }
    )

    def _find_scripture(self, key: str, value):
        """Return the first scripture whose ``key`` entry equals ``value``, or None."""
        return next((s for s in self.scriptures if s[key] == value), None)

    def get_scripture_by_collection(self, collection_name: str):
        """Return the scripture config whose ``collection_name`` matches.

        Raises:
            IndexError: if no scripture uses ``collection_name``. The exception
                type is what the previous ``[...][0]`` lookup raised, so
                existing handlers keep working; it now carries a message.
        """
        scripture = self._find_scripture("collection_name", collection_name)
        if scripture is None:
            raise IndexError(f"Unknown collection: {collection_name}")
        return scripture

    def get_scripture_by_name(self, scripture_name: str):
        """Return the scripture config whose ``name`` matches.

        Raises:
            IndexError: if no scripture is called ``scripture_name`` (same
                exception type as the previous ``[...][0]`` lookup).
        """
        scripture = self._find_scripture("name", scripture_name)
        if scripture is None:
            raise IndexError(f"Unknown scripture: {scripture_name}")
        return scripture

    def is_metadata_field_allowed(
        self, collection_name: str, metadata_where_clause: MetadataWhereClause
    ):
        """Check every filter in a where-clause against the collection's fields.

        The clause's ``filters`` are checked first, then each entry of its
        ``groups`` is checked recursively in the same way.

        Returns:
            True when every ``metadata_field`` is listed in the scripture's
            ``metadata_fields``.

        Raises:
            ValueError: on the first filter whose ``metadata_field`` is not
                allowed. ValueError is a subclass of Exception, so callers
                that catch Exception are unaffected.
            IndexError: if ``collection_name`` is unknown.
        """
        scripture = self.get_scripture_by_collection(collection_name=collection_name)
        allowed_fields = {field["name"] for field in scripture["metadata_fields"]}

        def validate_clause(clause: MetadataWhereClause):
            # Direct filters of this clause; ``filters`` may be empty/None.
            for f in clause.filters or ():
                if f.metadata_field not in allowed_fields:
                    raise ValueError(
                        f"metadata_field: [{f.metadata_field}] not allowed in collection [{collection_name}]. "
                        f"Here are the allowed fields with their descriptions: {scripture['metadata_fields']}"
                    )
            # Nested groups are clauses themselves, so recurse.
            for g in clause.groups or ():
                validate_clause(g)

        validate_clause(metadata_where_clause)
        return True

    def get_embedding_for_collection(self, collection_name: str):
        """Return the embedding function identifier for a collection.

        Uses the scripture's ``collection_embedding_fn`` when configured and
        falls back to ``"hf"`` (Hugging Face sentence transformers) otherwise.
        """
        scripture = self.get_scripture_by_collection(collection_name)
        return scripture.get("collection_embedding_fn", "hf")

    def remove_callables(self, obj):
        """Return a copy of ``obj`` with callable values stripped out.

        Dicts and lists are rebuilt recursively, dropping any value/element
        that is callable; every other object is returned unchanged.
        """
        if isinstance(obj, dict):
            return {
                k: self.remove_callables(v) for k, v in obj.items() if not callable(v)
            }
        if isinstance(obj, list):
            return [self.remove_callables(v) for v in obj if not callable(v)]
        return obj

    def filter_scriptures_fields(self, fields_to_keep: List[str]) -> List[Dict]:
        """
        Return a list of scripture dicts containing only the specified fields.

        Fields a scripture does not define are skipped, and callable values
        are removed from the result (see ``remove_callables``).
        """
        filtered = [
            {k: s[k] for k in fields_to_keep if k in s} for s in self.scriptures
        ]
        return self.remove_callables(filtered)

    def canonicalize_document(
        self, scripture_name: str, document_text: str, metadata_doc: dict
    ):
        """
        Convert scripture-specific document to a flattened canonical form.

        Each ``field_mapping`` value is either a metadata key (string) or a
        callable that receives ``metadata_doc``. Only keys listed in
        ``_CANONICAL_KEYS`` are copied from the mapping.

        Returns:
            A dict with the mapped canonical fields plus ``scripture_name``,
            ``scripture_title``, ``source``, ``language``, ``unit``,
            ``document``, ``text``, ``verse``, ``id`` and ``_global_index``.

        Raises:
            ValueError: if ``scripture_name`` is unknown, or if the resolved
                verse value is truthy but not convertible by ``int()``.
        """
        config = self._find_scripture("name", scripture_name)
        if not config:
            raise ValueError(f"Unknown scripture: {scripture_name}")

        mapping = config.get("field_mapping", {})

        def resolve_field(field):
            """Resolve a field: string key or lambda"""
            if callable(field):
                try:
                    return field(metadata_doc)
                except Exception:
                    # Best effort: a mapping callable that fails for this
                    # document yields None instead of aborting the conversion.
                    return None
            if isinstance(field, str):
                return metadata_doc.get(field)
            return None

        canonical_doc = {
            key: resolve_field(field)
            for key, field in mapping.items()
            if key in self._CANONICAL_KEYS
        }

        # Global fields taken straight from the scripture config. "unit"
        # deliberately comes last-wins from the config, replacing any mapped value.
        canonical_doc["scripture_name"] = config.get("name")
        canonical_doc["scripture_title"] = config.get("title")
        canonical_doc["source"] = config.get("source")
        canonical_doc["language"] = config.get("language")
        canonical_doc["unit"] = config.get("unit")
        canonical_doc["document"] = document_text

        # When no usable text was mapped (missing, None or the "-" placeholder),
        # promote the raw document to "text" and blank out "document".
        if canonical_doc.get("text") in (None, "-"):
            canonical_doc["text"] = canonical_doc["document"]
            canonical_doc["document"] = "-"

        # The verse comes from "unit_field" if configured, else from "unit".
        verse = resolve_field(config.get("unit_field", config.get("unit")))
        if verse == "-":
            canonical_doc["verse"] = -1
        else:
            canonical_doc["verse"] = int(verse) if verse else 0

        # These two always come from the metadata, overriding any mapping.
        canonical_doc["id"] = resolve_field("id")
        canonical_doc["_global_index"] = resolve_field("_global_index")
        return canonical_doc

    def get_collection_name(self, scripture_name):
        """Return the ``collection_name`` configured for ``scripture_name``.

        Returns None when the scripture has no ``collection_name`` entry.

        Raises:
            ValueError: if ``scripture_name`` is unknown (previously this
                surfaced as an AttributeError on None).
        """
        config = self._find_scripture("name", scripture_name)
        if config is None:
            raise ValueError(f"Unknown scripture: {scripture_name}")
        return config.get("collection_name")
if __name__ == "__main__":
    # Debug aid when the module is run directly: dump the full scripture
    # configuration, then the collection names on their own.
    print(SanatanConfig.scriptures)
    # Print the list so this statement has a visible effect rather than
    # building a value that is immediately discarded.
    print([scripture["collection_name"] for scripture in SanatanConfig.scriptures])