# datajoi-sql-agent / src/pipelines.py
# Muhammad Mustehson
# Update Old Code
# a360e3c
import logging
import os
import pandas as pd
from duckdb import DuckDBPyConnection
from src.models import SQLQueryModel
from src.prompts import SQL_PROMPT, USER_PROMPT
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Default number of SQL generation retries; read once at import time from the
# SQL_GENERATION_RETRIES environment variable and falling back to 5 when unset.
SQL_GENERATION_RETRIES = int(os.environ.get("SQL_GENERATION_RETRIES", "5"))
class SQLPipeline:
    """Generate SQL through a chain object and execute it on DuckDB.

    The pipeline formats a prompt from the user's question and context, asks
    ``chain`` for a SQL query, runs that query on the DuckDB connection and,
    when generation or execution fails (or yields no rows), retries with the
    accumulated failure messages appended to the prompt.
    """

    def __init__(
        self,
        duckdb: DuckDBPyConnection,
        chain,
    ) -> None:
        """Store the collaborators used by the pipeline.

        Args:
            duckdb: DuckDB connection that generated queries are run against.
            chain: Object exposing ``run(system_prompt=..., user_prompt=...,
                format_name=..., response_format=...)``. Its return value is
                handed back unchanged by :meth:`generate_sql`.
        """
        self._duckdb = duckdb
        self.chain = chain

    def generate_sql(
        self, user_question: str, context: str, errors: str | None = None
    ) -> str | dict[str, str | int | float | None] | list[str] | None:
        """Ask the chain for a SQL query answering ``user_question``.

        Args:
            user_question: Question inserted into ``USER_PROMPT``.
            context: Context inserted into ``USER_PROMPT``.
            errors: Failure messages from earlier attempts. When non-empty, a
                retry hint containing them is appended to the user prompt.

        Returns:
            Whatever ``chain.run`` returns for the ``sql_query`` format.
        """
        user_prompt_formatted = USER_PROMPT.format(
            question=user_question, context=context
        )
        if errors:
            # Adjacent literals instead of backslash-newline continuations:
            # a continuation inside a string literal either embeds the next
            # line's indentation in the prompt or fuses the neighbouring
            # words. The leading newline keeps the hint separate from the
            # formatted user prompt.
            user_prompt_formatted += (
                "\nCarefully review the previous error or exception and "
                "rewrite the SQL so that the error does not occur again. "
                "Try a different approach or rewrite SQL if needed. "
                f"Last error: {errors}"
            )
        sql = self.chain.run(
            system_prompt=SQL_PROMPT,
            user_prompt=user_prompt_formatted,
            format_name="sql_query",
            response_format=SQLQueryModel,
        )
        logger.info("SQL Generated Successfully: %s", sql)
        return sql

    def run_query(self, sql_query: str) -> pd.DataFrame | None:
        """Execute ``sql_query`` on the DuckDB connection.

        Args:
            sql_query: SQL text passed to the connection's ``query`` method.

        Returns:
            The query result converted with ``.df()``.
        """
        logger.info("Query Execution Started.")
        return self._duckdb.query(sql_query).df()

    def try_sql_with_retries(
        self,
        user_question: str,
        context: str,
        max_retries: int = SQL_GENERATION_RETRIES,
    ) -> tuple[
        str | dict[str, str | int | float | None] | list[str] | None,
        pd.DataFrame | None,
    ]:
        """Generate and execute SQL, retrying on failure or empty results.

        One initial attempt is made, followed by up to ``max_retries``
        retries. Every failed attempt (an exception during generation or
        execution, or a query that yields no rows) is recorded, and the
        accumulated messages are passed to :meth:`generate_sql` on the next
        attempt.

        Args:
            user_question: Question forwarded to :meth:`generate_sql`.
            context: Context forwarded to :meth:`generate_sql`.
            max_retries: Number of retries allowed after the first attempt.

        Returns:
            ``(sql, dataframe)`` for the first attempt whose query returns a
            non-empty frame, where ``sql`` is the value produced by
            :meth:`generate_sql`. ``(None, None)`` if generation returns a
            falsy value or every attempt fails.
        """
        failures: list[str] = []
        # The first attempt is the normal run and is not counted as a retry.
        total_attempts = max_retries + 1

        for attempt in range(1, total_attempts + 1):
            try:
                if attempt > 1:
                    logger.info("Retrying: %d", attempt - 1)

                # No failures yet -> None, so the first prompt carries no hint.
                sql = self.generate_sql(
                    user_question, context, errors="".join(failures) or None
                )
                if not sql:
                    return None, None

                # The chain may return a mapping holding the query or the
                # query text itself; anything else cannot be executed.
                sql_query_str = sql.get("sql_query") if isinstance(sql, dict) else sql
                if not isinstance(sql_query_str, str):
                    raise ValueError(
                        f"Expected SQL query to be a string, got {type(sql_query_str).__name__}"
                    )

                query_df = self.run_query(sql_query_str)
                if query_df is not None and not query_df.empty:
                    return sql, query_df

                # An empty result also triggers a retry; record it so the next
                # prompt has something to act on instead of repeating blindly.
                failure = f"\n[Attempt {attempt}] Query executed but returned no rows."
                logger.warning("SQL query returned no rows: %s", failure)
            except Exception as e:
                # Deliberately broad: any failure (chain call, validation,
                # DuckDB execution) becomes feedback for the next attempt.
                failure = f"\n[Attempt {attempt}] {type(e).__name__}: {e}"
                logger.error("Error during SQL generation or execution: %s", failure)
            failures.append(failure)

        logger.error(
            "Failed after %d attempts. Errors: %s",
            max(total_attempts, 0),
            "".join(failures),
        )
        return None, None