Spaces:
Running
Running
| import logging | |
| import os | |
| import pandas as pd | |
| from duckdb import DuckDBPyConnection | |
| from src.models import SQLQueryModel | |
| from src.prompts import SQL_PROMPT, USER_PROMPT | |
| logger = logging.getLogger(__name__) | |
| SQL_GENERATION_RETRIES = int(os.getenv("SQL_GENERATION_RETRIES", "5")) | |
| class SQLPipeline: | |
| def __init__( | |
| self, | |
| duckdb: DuckDBPyConnection, | |
| chain, | |
| ) -> None: | |
| self._duckdb = duckdb | |
| self.chain = chain | |
| def generate_sql( | |
| self, user_question: str, context: str, errors: str | None = None | |
| ) -> str | dict[str, str | int | float | None] | list[str] | None: | |
| """Generate SQL + description.""" | |
| user_prompt_formatted = USER_PROMPT.format( | |
| question=user_question, context=context | |
| ) | |
| if errors: | |
| user_prompt_formatted += f"Carefully review the previous error or\ | |
| exception and rewrite the SQL so that the error does not occur again.\ | |
| Try a different approach or rewrite SQL if needed. Last error: {errors}" | |
| sql = self.chain.run( | |
| system_prompt=SQL_PROMPT, | |
| user_prompt=user_prompt_formatted, | |
| format_name="sql_query", | |
| response_format=SQLQueryModel, | |
| ) | |
| logger.info(f"SQL Generated Successfully: {sql}") | |
| return sql | |
| def run_query(self, sql_query: str) -> pd.DataFrame | None: | |
| """Execute SQL and return dataframe.""" | |
| logger.info("Query Execution Started.") | |
| return self._duckdb.query(sql_query).df() | |
| def try_sql_with_retries( | |
| self, | |
| user_question: str, | |
| context: str, | |
| max_retries: int = SQL_GENERATION_RETRIES, | |
| ) -> tuple[ | |
| str | dict[str, str | int | float | None] | list[str] | None, | |
| pd.DataFrame | None, | |
| ]: | |
| """Try SQL generation + execution with retries.""" | |
| last_error = None | |
| all_errors = "" | |
| for attempt in range( | |
| 1, max_retries + 2 | |
| ): # @ Since the first is normal and not consider in retries | |
| try: | |
| if attempt > 1 and last_error: | |
| logger.info(f"Retrying: {attempt - 1}") | |
| # Generate SQL | |
| sql = self.generate_sql(user_question, context, errors=all_errors) | |
| if not sql: | |
| return None, None | |
| else: | |
| # Generate SQL | |
| sql = self.generate_sql(user_question, context) | |
| if not sql: | |
| return None, None | |
| # Try executing query | |
| sql_query_str = sql.get("sql_query") if isinstance(sql, dict) else sql | |
| if not isinstance(sql_query_str, str): | |
| raise ValueError( | |
| f"Expected SQL query to be a string, got {type(sql_query_str).__name__}" | |
| ) | |
| query_df = self.run_query(sql_query_str) | |
| # If execution succeeds, stop retrying or if df is not empty | |
| if query_df is not None and not query_df.empty: | |
| return sql, query_df | |
| except Exception as e: | |
| last_error = f"\nAttempt {attempt - 1}] {type(e).__name__}: {e}" | |
| logger.error(f"Error during SQL generation or execution: {last_error}") | |
| all_errors += last_error | |
| logger.error(f"Failed after {max_retries} attempts. Last error: {all_errors}") | |
| return None, None | |