import logging
import ast
from typing import List, Dict, Any, Union

from dotenv import load_dotenv

# LangChain imports
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.messages import SystemMessage, HumanMessage

# Local imports
from .utils import getconfig, get_auth

# Load environment variables (e.g. provider API keys) from a .env file
load_dotenv()

# ---------------------------------------------------------------------
# Model / client initialization (non-exhaustive list of providers)
# ---------------------------------------------------------------------
config = getconfig("params.cfg")

PROVIDER = config.get("generator", "PROVIDER")
MODEL = config.get("generator", "MODEL")
MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
INFERENCE_PROVIDER = config.get("generator", "INFERENCE_PROVIDER")
ORGANIZATION = config.get("generator", "ORGANIZATION")

# Set up authentication for the selected provider
auth_config = get_auth(PROVIDER)


def get_chat_model():
    """Initialize the appropriate LangChain chat model for the configured provider."""
    common_params = {
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
    }

    if PROVIDER == "openai":
        return ChatOpenAI(
            model=MODEL,
            openai_api_key=auth_config["api_key"],
            **common_params
        )
    elif PROVIDER == "anthropic":
        return ChatAnthropic(
            model=MODEL,
            anthropic_api_key=auth_config["api_key"],
            **common_params
        )
    elif PROVIDER == "cohere":
        return ChatCohere(
            model=MODEL,
            cohere_api_key=auth_config["api_key"],
            **common_params
        )
    elif PROVIDER == "huggingface":
        # HuggingFaceEndpoint takes generation parameters directly rather than
        # via common_params (it expects max_new_tokens, not max_tokens)
        llm = HuggingFaceEndpoint(
            repo_id=MODEL,
            huggingfacehub_api_token=auth_config["api_key"],
            task="text-generation",
            provider=INFERENCE_PROVIDER,
            server_kwargs={"bill_to": ORGANIZATION},
            temperature=TEMPERATURE,
            max_new_tokens=MAX_TOKENS
        )
        return ChatHuggingFace(llm=llm)
    else:
        raise ValueError(f"Unsupported provider: {PROVIDER}")


# Initialize the provider-agnostic chat model once at import time
chat_model = get_chat_model()

# ---------------------------------------------------------------------
# Context processing - may need further refinement (e.g. to handle other data sources)
# ---------------------------------------------------------------------
def extract_relevant_fields(retrieval_results: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
    """
    Extract only the relevant fields from retrieval results.

    Args:
        retrieval_results: List of result dictionaries from the retriever,
            or a string representation of such a list

    Returns:
        List of processed objects containing only the relevant fields
    """
    # The retriever may hand results over as a serialized string; parse it
    # back into a Python list before processing (literal_eval would fail on
    # an actual list, so only parse strings)
    if isinstance(retrieval_results, str):
        retrieval_results = ast.literal_eval(retrieval_results)

    processed_results = []

    for result in retrieval_results:
        # Extract the answer content
        answer = result.get('answer', '')

        # Extract document identification from metadata
        metadata = result.get('answer_metadata', {})
        doc_info = {
            'answer': answer,
            'filename': metadata.get('filename', 'Unknown'),
            'page': metadata.get('page', 'Unknown'),
            'year': metadata.get('year', 'Unknown'),
            'source': metadata.get('source', 'Unknown'),
            'document_id': metadata.get('_id', 'Unknown')
        }
        processed_results.append(doc_info)

    return processed_results


def format_context_from_results(processed_results: List[Dict[str, Any]]) -> str:
    """
    Format processed retrieval results into a context string for the LLM.
    Args:
        processed_results: List of processed objects with relevant fields

    Returns:
        Formatted context string
    """
    if not processed_results:
        return ""

    context_parts = []

    for i, result in enumerate(processed_results, 1):
        # Build a citation-style header for each retrieved chunk
        doc_reference = f"[Document {i}: {result['filename']}"
        if result['page'] != 'Unknown':
            doc_reference += f", Page {result['page']}"
        if result['year'] != 'Unknown':
            doc_reference += f", Year {result['year']}"
        doc_reference += "]"

        context_part = f"{doc_reference}\n{result['answer']}\n"
        context_parts.append(context_part)

    return "\n".join(context_parts)


# ---------------------------------------------------------------------
# Core generation function for both Gradio UI and MCP
# ---------------------------------------------------------------------
async def _call_llm(messages: list) -> str:
    """
    Provider-agnostic LLM call using LangChain.

    Args:
        messages: List of LangChain message objects

    Returns:
        Generated response content as a string
    """
    try:
        # Use the async invoke for better performance under concurrent requests
        response = await chat_model.ainvoke(messages)
        return response.content.strip()
    except Exception as e:
        logging.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
        raise


def build_messages(question: str, context: str) -> list:
    """
    Build messages in LangChain format.

    Args:
        question: The user's question
        context: The relevant context for answering

    Returns:
        List of LangChain message objects
    """
    system_content = (
        "You are an expert assistant. Answer the USER question using only the "
        "CONTEXT provided. If the context is insufficient, say 'I don't know.'"
    )
    user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"

    return [
        SystemMessage(content=system_content),
        HumanMessage(content=user_content)
    ]


async def generate(query: str, context: Union[str, List[Dict[str, Any]]]) -> str:
    """
    Generate an answer to a query from the provided context (RAG).

    Takes a user query and relevant context, then uses a language model to
    generate an answer grounded in the provided information.

    Args:
        query: User query
        context: Pre-formatted context string (e.g. from the Gradio UI) or a
            list of retrieval result dictionaries (from the retriever)

    Returns:
        The generated answer based on the query and context
    """
    if not query.strip():
        return "Error: Query cannot be empty"

    # Handle both string context (for Gradio UI) and list context (from retriever)
    if isinstance(context, list):
        if not context:
            return "Error: No retrieval results provided"

        # Process the retrieval results
        processed_results = extract_relevant_fields(context)
        formatted_context = format_context_from_results(processed_results)

        if not formatted_context.strip():
            return "Error: No valid content found in retrieval results"
    elif isinstance(context, str):
        if not context.strip():
            return "Error: Context cannot be empty"
        formatted_context = context
    else:
        return "Error: Context must be either a string or a list of retrieval results"

    try:
        messages = build_messages(query, formatted_context)
        answer = await _call_llm(messages)
        return answer
    except Exception as e:
        logging.exception("Generation failed")
        return f"Error: {str(e)}"
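

# ---------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal example of calling generate()
# with a hand-built retrieval result. The field values below are hypothetical;
# real results come from the retriever, and params.cfg must point at a valid
# provider/model for this to actually run.
# ---------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    # Hypothetical retrieval result matching the shape extract_relevant_fields expects
    sample_results = [
        {
            "answer": "Global emissions fell by roughly 5% in 2020.",
            "answer_metadata": {
                "filename": "example_report.pdf",
                "page": 12,
                "year": 2020,
                "source": "example",
                "_id": "doc-001",
            },
        }
    ]

    # generate() is async, so drive it with asyncio.run in a script context
    answer = asyncio.run(generate("What happened to emissions in 2020?", sample_results))
    print(answer)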