import logging
import os

import pandas as pd
from duckdb import DuckDBPyConnection

from src.models import SQLQueryModel
from src.prompts import SQL_PROMPT, USER_PROMPT

logger = logging.getLogger(__name__)

# Number of additional attempts made after the first one (the first attempt is
# not counted as a retry). Parsed at import time, so a non-integer value in the
# environment raises ValueError on import.
SQL_GENERATION_RETRIES = int(os.getenv("SQL_GENERATION_RETRIES", "5"))


class SQLPipeline:
    """Generate SQL with an LLM chain and execute it against DuckDB, with retries."""

    def __init__(
        self,
        duckdb: DuckDBPyConnection,
        chain,
    ) -> None:
        """Store the DuckDB connection and the LLM chain.

        Args:
            duckdb: Connection used to execute generated SQL.
            chain: Object exposing ``run(system_prompt=..., user_prompt=...,
                format_name=..., response_format=...)``; it produces the SQL.
        """
        self._duckdb = duckdb
        self.chain = chain

    def generate_sql(
        self, user_question: str, context: str, errors: str | None = None
    ) -> str | dict[str, str | int | float | None] | list[str] | None:
        """Generate SQL (and its description) for ``user_question``.

        Args:
            user_question: Natural-language question to answer.
            context: Context interpolated into ``USER_PROMPT``.
            errors: Errors from earlier attempts. When non-empty they are
                appended to the user prompt so the model can correct itself.

        Returns:
            Whatever ``self.chain.run`` returns for the ``SQLQueryModel``
            response format. ``try_sql_with_retries`` accepts either a dict with
            a ``"sql_query"`` key or a plain SQL string (presumably the chain
            returns one of these — TODO confirm against the chain).
        """
        user_prompt_formatted = USER_PROMPT.format(
            question=user_question, context=context
        )
        if errors:
            # Start a new paragraph so the feedback is not glued onto the end
            # of the base prompt, and keep the text free of source indentation.
            user_prompt_formatted += (
                "\n\nCarefully review the previous error or exception and "
                "rewrite the SQL so that the error does not occur again. "
                "Try a different approach or rewrite SQL if needed. "
                f"Previous errors: {errors}"
            )
        sql = self.chain.run(
            system_prompt=SQL_PROMPT,
            user_prompt=user_prompt_formatted,
            format_name="sql_query",
            response_format=SQLQueryModel,
        )
        logger.info("SQL Generated Successfully: %s", sql)
        return sql

    def run_query(self, sql_query: str) -> pd.DataFrame | None:
        """Execute ``sql_query`` on the DuckDB connection and return a DataFrame."""
        logger.info("Query Execution Started.")
        return self._duckdb.query(sql_query).df()

    def try_sql_with_retries(
        self,
        user_question: str,
        context: str,
        max_retries: int = SQL_GENERATION_RETRIES,
    ) -> tuple[
        str | dict[str, str | int | float | None] | list[str] | None,
        pd.DataFrame | None,
    ]:
        """Generate and execute SQL, retrying until a non-empty result is obtained.

        Each failed attempt (an exception from generation or execution) is
        recorded and fed back to the model on the next attempt. An attempt
        whose query succeeds but yields an empty DataFrame is also retried,
        without extra feedback.

        Args:
            user_question: Natural-language question to answer.
            context: Context passed through to ``generate_sql``.
            max_retries: Extra attempts after the first one; negative values
                are treated as 0, so at least one attempt is always made.

        Returns:
            ``(sql, dataframe)`` for the first attempt that yields a non-empty
            DataFrame. Returns ``(None, None)`` if the model produces no SQL,
            or if every attempt fails or returns no rows.
        """
        total_attempts = max(max_retries, 0) + 1  # the first attempt is not a retry
        errors: list[str] = []

        for attempt in range(1, total_attempts + 1):
            if attempt > 1:
                logger.info("Retrying: %d", attempt - 1)
            # Broad catch on purpose: LLM, validation and DuckDB failures are
            # all fed back to the model as feedback for the next attempt.
            try:
                sql = self.generate_sql(
                    user_question, context, errors="".join(errors) or None
                )
                if not sql:
                    return None, None

                sql_query_str = sql.get("sql_query") if isinstance(sql, dict) else sql
                if not isinstance(sql_query_str, str):
                    raise ValueError(
                        "Expected SQL query to be a string, "
                        f"got {type(sql_query_str).__name__}"
                    )
                query_df = self.run_query(sql_query_str)
            except Exception as e:
                error = f"\n[Attempt {attempt}] {type(e).__name__}: {e}"
                logger.error("Error during SQL generation or execution: %s", error)
                errors.append(error)
                continue

            # Success only counts when the query actually returned rows.
            if query_df is not None and not query_df.empty:
                return sql, query_df
            logger.warning("Attempt %d returned no rows.", attempt)

        logger.error(
            "Failed after %d attempts. Errors: %s",
            total_attempts,
            "".join(errors) or "none (queries returned no rows)",
        )
        return None, None