Cheh Kit Hong committed
Commit aa018e3 Β· 1 Parent(s): 0fc97a4

fixing gradio

README.md CHANGED
@@ -2,14 +2,12 @@ rag_agent/
 β”œβ”€β”€ app.py                  # Main Gradio application entry point
 β”œβ”€β”€ config.py               # Configuration hub (models, chunk sizes, providers)
 β”œβ”€β”€ util.py                 # PDF to markdown conversion
-β”œβ”€β”€ document_chunker.py     # Chunking strategy
 β”œβ”€β”€ core/                   # Core RAG components orchestration
 β”‚   β”œβ”€β”€ chat_interface.py
 β”‚   β”œβ”€β”€ document_manager.py
 β”‚   └── rag_system.py
-β”œβ”€β”€ knowledge_base/         # Storage management
-β”‚   β”œβ”€β”€ chroma.py           # Parent chunks storage (JSON)
-β”‚   └── vector_db_manager.py
+β”œβ”€β”€ knowledge_base/         # Scripts that build the ChromaDB index
+β”œβ”€β”€ chroma_data/            # Chroma vector store data
 β”œβ”€β”€ agent_logic/            # LangGraph agent workflow
 β”‚   β”œβ”€β”€ edges.py            # Conditional routing logic
 β”‚   β”œβ”€β”€ graph.py            # Graph construction and compilation
agent/graph.py CHANGED
@@ -13,12 +13,14 @@ def create_agent_graph(llm, vectordb, search_tools) -> StateGraph:
     graph = StateGraph(AgentState)
     checkpointer = MemorySaver()
 
+    llm_with_tools = llm.bind_tools(search_tools)
     web_search_tool_node = ToolNode(search_tools)
 
     # --- Nodes ---
     graph.add_node("router_node", partial(router_node, llm=llm))
     graph.add_node("vectordb_node", partial(vectordb_node, vectorstore=vectordb))
-    graph.add_node("web_search_node", web_search_tool_node)
+    graph.add_node("web_search_agent_node", partial(web_search_agent_node, llm=llm_with_tools))
+    graph.add_node("web_search_tool_node", web_search_tool_node)
     graph.add_node("generate_node", partial(generate_node, llm=llm))
 
     # --- Edges ---
@@ -28,16 +30,23 @@ def create_agent_graph(llm, vectordb, search_tools) -> StateGraph:
         "router_node",
         routing_logic,
         {
-            # Output from routing_logic -> Target Node Name
             "vectordb_node": "vectordb_node",
-            "web_search_node": "web_search_node",
+            "web_search_agent_node": "web_search_agent_node",
             "generate_node": "generate_node",
-            # If your logic has an 'else' that returns END, you don't list it here.
+        }
+    )
+
+    graph.add_conditional_edges(
+        "web_search_agent_node",
+        tools_condition,
+        {
+            "tools": "web_search_tool_node",  # tools_condition returns "tools" when the AIMessage carries tool_calls
+            "__end__": "generate_node",       # and "__end__" when it carries none
         }
     )
 
     graph.add_edge("vectordb_node", "generate_node")
-    graph.add_edge("web_search_node", "generate_node")
+    graph.add_edge("web_search_tool_node", "generate_node")
 
     graph.add_edge("generate_node", END)
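For reviewers unfamiliar with the prebuilt helpers: `tools_condition` inspects the last `AIMessage` and returns `"tools"` if it carries `tool_calls`, `"__end__"` otherwise, which is why the mapping above uses those two keys. Below is a minimal sketch of the same agent-to-tools wiring, assuming any chat model `llm` that supports `.bind_tools`; the toy `wikipedia_search` tool is a stand-in, not the repo's real tool.

```python
from langchain_core.tools import tool
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import ToolNode, tools_condition

@tool
def wikipedia_search(query: str) -> str:
    """Toy stand-in for a real search tool."""
    return f"Results for: {query}"

def build_search_loop(llm):
    """Wire an agent node to a ToolNode via tools_condition."""
    llm_with_tools = llm.bind_tools([wikipedia_search])

    def agent(state: MessagesState):
        # The bound model emits an AIMessage; if it chose to search,
        # the message carries tool_calls and tools_condition routes to "tools".
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    g = StateGraph(MessagesState)
    g.add_node("agent", agent)
    g.add_node("tools", ToolNode([wikipedia_search]))
    g.add_edge(START, "agent")
    g.add_conditional_edges("agent", tools_condition)  # -> "tools" or END
    g.add_edge("tools", "agent")  # feed tool results back to the agent
    return g.compile()
```

Note the design difference: the sketch loops tool output back to the agent, while this commit routes `web_search_tool_node` straight to `generate_node`, which caps the flow at a single search round.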
agent/more_nodes.py DELETED
@@ -1,97 +0,0 @@
-from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, RemoveMessage
-from typing import Literal
-
-from .state import AgentState, QueryAnalysis
-from .prompts import *
-
-def analyze_chat_and_summarize(state: AgentState, llm):
-    """
-    Analyzes chat history and summarizes key points for context.
-    """
-    if len(state["messages"]) < 4:  # Need some history to summarize
-        return {"conversation_summary": ""}
-
-    # Extract relevant messages (excluding current query and system messages)
-    relevant_msgs = [
-        msg for msg in state["messages"][:-1]  # Exclude current query
-        if isinstance(msg, (HumanMessage, AIMessage))
-        and not getattr(msg, "tool_calls", None)
-    ]
-
-    if not relevant_msgs:
-        return {"conversation_summary": ""}
-
-    summary_prompt = """**Summarize the key topics and context from this conversation concisely (1-2 sentences max).**
-    Discard irrelevant information, such as misunderstandings or off-topic queries/responses.
-    If there are no key topics, return an empty string.
-
-    """
-    for msg in relevant_msgs[-6:]:  # Last 6 messages for context
-        role = "User" if isinstance(msg, HumanMessage) else "Assistant"
-        summary_prompt += f"{role}: {msg.content}\n"
-
-    summary_prompt += "\nBrief Summary:"
-    summary_response = llm.with_config(temperature=0.3).invoke([SystemMessage(content=summary_prompt)])
-    return {"conversation_summary": summary_response.content}
-
-def analyze_and_rewrite_query(state: AgentState, llm):
-    """
-    Analyzes user query and rewrites it for clarity, optionally using conversation context.
-    """
-    last_message = state["messages"][-1]
-    conversation_summary = state.get("conversation_summary", "")
-
-    context_section = (
-        f"**Conversation Context:**\n{conversation_summary}"
-        if conversation_summary.strip()
-        else "**Conversation Context:**\n[First query in conversation]"
-    )
-
-    # Create analysis prompt
-    query_analysis_prompt = get_query_analysis_prompt(last_message.content, conversation_summary)
-
-    llm_with_structure = llm.with_config(temperature=0.3).with_structured_output(QueryAnalysis)
-    response = llm_with_structure.invoke([SystemMessage(content=query_analysis_prompt)])
-
-    if response.is_clear:
-        # Remove all non-system messages
-        delete_all = [
-            RemoveMessage(id=m.id)
-            for m in state["messages"]
-            if not isinstance(m, SystemMessage)
-        ]
-
-        # Format rewritten query
-        rewritten = (
-            "\n".join([f"{i+1}. {q}" for i, q in enumerate(response.questions)])
-            if len(response.questions) > 1
-            else response.questions[0]
-        )
-        return {
-            "questionIsClear": True,
-            "messages": delete_all + [HumanMessage(content=rewritten)]
-        }
-    else:
-        clarification = response.clarification_needed or "I need more information to understand your question."
-        return {
-            "questionIsClear": False,
-            "messages": [AIMessage(content=clarification)]
-        }
-
-def human_input_node(state: AgentState):
-    """Placeholder node for human-in-the-loop interruption"""
-    return {}
-
-def route_after_rewrite(state: AgentState) -> Literal["agent", "human_input"]:
-    """Route to agent if question is clear, otherwise wait for human input"""
-    return "agent" if state.get("questionIsClear", False) else "human_input"
-
-def agent_node(state: AgentState, llm_with_tools):
-    """Main agent node that processes queries using tools"""
-    system_prompt = get_system_prompt()
-    messages = [system_prompt] + state["messages"]
-    response = llm_with_tools.invoke(messages)
-    return {"messages": [response]}
-
-if __name__ == "__main__":
-    pass
agent/nodes.py CHANGED
@@ -1,4 +1,4 @@
-from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, RemoveMessage
+from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, RemoveMessage
 from typing import Literal
 from langgraph.graph import START, END
 
@@ -13,17 +13,17 @@ def router_node(state: AgentState, llm):
     """
     query = state["messages"][-1].content
    rag_method_prompt = determine_rag_method_prompt()
-    rag_method_result = llm.invoke([rag_method_prompt, HumanMessage(content=query)])
+    rag_method_result = llm.invoke([SystemMessage(content=rag_method_prompt), HumanMessage(content=query)])
     rag_method = rag_method_result.content.strip().upper()
     state["rag_method"] = rag_method
     return state
 
-def routing_logic(self, state: AgentState) -> str:
+def routing_logic(state: AgentState) -> str:
     rag_method = state["rag_method"]
     if rag_method == "RAG":
         return "vectordb_node"
     elif rag_method == "WEBSEARCH":
-        return "web_search_node"
+        return "web_search_agent_node"
     elif rag_method == "GENERAL":
         return "generate_node"  # fall back to generate_node when the question needs neither RAG nor web search
     else:
@@ -31,7 +31,7 @@ def routing_logic(self, state: AgentState) -> str:
         print(f"ERROR: Router returned unclassified intent: {rag_method}. Terminating flow.")
         return END
 
-def vectordb_node(state: AgentState, llm, vectorstore):
+def vectordb_node(state: AgentState, vectorstore):
     """
     Use vectordb to answer the query.
     """
@@ -43,12 +43,40 @@ def vectordb_node(state: AgentState, llm, vectorstore):
     state["context"] = context
     return state
 
+def web_search_agent_node(state: AgentState, llm):
+    """
+    LLM agent that decides which web search tools to call.
+    This generates an AIMessage with tool_calls.
+    """
+    messages = state["messages"]
+
+    # Add an instruction to use the tools
+    system_msg = SystemMessage(content="""You are a web search assistant.
+    Use the available search tools (web_search_tavily, wikipedia_search) to find information about the user's query.
+    Call the appropriate tool with the query.""")
+
+    messages_with_system = [system_msg] + messages
+
+    # The tool-bound LLM generates an AIMessage with tool_calls
+    response = llm.invoke(messages_with_system)
+
+    return {"messages": [response]}
+
 def generate_node(state: AgentState, llm):
     messages = state["messages"][-10:]  # Limit to the last 10 messages to stay within the token limit
     context = state.get("context", [])
 
     system_content = get_system_prompt()
 
+    # Extract web search results from ToolMessages if available
+    if not context:
+        context = ""  # start from a string so the += below concatenates cleanly
+        for msg in reversed(messages):
+            if isinstance(msg, ToolMessage):
+                # Web search results arrive as ToolMessage content
+                if msg.content:
+                    context += f"\n\n{msg.content}"
+
     if context:
         system_content += f"\n\nRelevant Context:\n{context}"
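The new `generate_node` branch relies on tool results arriving as `ToolMessage` entries in the message list. A self-contained sketch of that fold, with a hand-built history standing in for a real tool round-trip (all contents are placeholders):

```python
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

# Simulated history after one tool round-trip: the agent's tool call,
# then the ToolNode's result.
messages = [
    HumanMessage(content="Latest AI news in 2025"),
    AIMessage(content="", tool_calls=[{
        "name": "web_search_tavily",
        "args": {"query": "latest AI news 2025"},
        "id": "call_1",
    }]),
    ToolMessage(content="Tavily result: ...", tool_call_id="call_1"),
]

# The same fold generate_node performs: gather tool output into context.
context = ""
for msg in reversed(messages):
    if isinstance(msg, ToolMessage) and msg.content:
        context += f"\n\n{msg.content}"

print(context.strip())  # -> "Tavily result: ..."
```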
agent/prompts.py CHANGED
@@ -2,87 +2,31 @@ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
 
 def get_system_prompt() -> SystemMessage:
     """Generate the system prompt for the RAG agent."""
-    return SystemMessage(content="""
-    You are an intelligent assistant that MUST use the available tools to answer questions.
-
-    **MANDATORY WORKFLOW β€” Follow these steps for EVERY question:**
-    1. **Call `search_chroma`** with the user's query (K = 3–7) to find the most relevant chunks in the Chroma vector store.
-    2. **Review the retrieved chunks** and identify the relevant ones. The chunks will contain content and metadata (such as `parent_id` and `source`).
-    3. **If additional context is needed**, retrieve more details from the source tools (e.g., Wikipedia or Arxiv) to provide the full answer.
-    4. **Use metadata** such as `source` and `parent_id` to help clarify or support the answer when applicable.
-    5. **Answer using ONLY the retrieved information**:
-       - Combine relevant chunks and use metadata (e.g., citation sources) as needed to clarify or support the response.
-    6. **If no relevant information is found**, rewrite the query into an **answer-focused declarative statement** and search again **only once** using `search_chroma`.
-    7. **Return the final answer** derived from the most relevant results.
-    """)
+    return """
+    You are a helpful assistant tasked with answering questions using a set of tools.
+    Follow the ReAct framework: iteratively reason through the problem step by step, use tools when necessary, and refine your approach based on tool outputs.
+    You will be provided with relevant context from the knowledge base when required. Use this context to inform your response, but feel free to supplement it with your own knowledge when appropriate. Context is provided in the state under the 'context' key.
+    You also have access to web search tools such as Tavily, Wikipedia, and Arxiv.
+    DO NOT make any assumptions.
+    """
 
 def determine_rag_method_prompt() -> str:
-    return SystemMessage(content="""
-    You are an rag method classification model. Given the user's query, you must classify the method to use
+    return """
+    You are a query classification model. Given the user's query, you must classify the method to use
     as one and only one of the following options:
 
-    1. **RAG**: The query likely relates to the internal, domain-specific documents you have access to.
-    2. **WEBSEARCH**: The query requires real-time facts, general knowledge, or external information not in your documents.
-    3. **GENERAL**: The query can be answered based on your existing knowledge without external resources.
-
-    Respond STRICTLY with only one of these words: RAG, WEBSEARCH, or GENERAL. Do not include any punctuation, explanation, or extra text.
-    """
-    )
-
-def get_conversation_summary_prompt(messages):
-    """Generate a prompt for conversation summarization."""
-    summary_prompt = """**Summarize the key topics and context from this conversation concisely (1-2 sentences max).**
-    Discard irrelevant information, such as misunderstandings or off-topic queries/responses.
-    If there are no key topics, return an empty string.
-
-    """
-
-    for msg in messages[-6:]:
-        role = "User" if isinstance(msg, HumanMessage) else "Assistant"
-        summary_prompt += f"{role}: {msg.content}\n"
-
-    summary_prompt += "\n**Brief Summary:**"
-    return summary_prompt
-
-def get_query_analysis_prompt(query: str, conversation_summary: str = "") -> str:
-    """Generate a prompt for query analysis and rewriting."""
-    context_section = (
-        f"**Conversation Context:**\n{conversation_summary}"
-        if conversation_summary.strip()
-        else "**Conversation Context:**\n[First query in conversation]"
-    )
-
-    return f"""
-    **Rewrite the user's query** to be clear, self-contained, and optimized for information retrieval.
-
-    **User Query:**
-    "{query}"
-
-    {context_section}
-
-    **Instructions:**
-
-    1. **Resolve references for follow-ups:**
-       - If the query uses pronouns or refers to previous topics, use the context to make it self-contained.
-
-    2. **Ensure clarity for new queries:**
-       - Make the query specific, concise, and unambiguous.
-
-    3. **Correct errors and interpret intent:**
-       - If the query is grammatically incorrect, contains typos, or has abbreviations, correct it and infer the intended meaning.
-
-    4. **Split only when necessary:**
-       - If multiple distinct questions exist, split into **up to 3 focused sub-queries** to avoid over-segmentation.
-       - Each sub-query must still be meaningful on its own.
-
-    5. **Optimize for search:**
-       - Use **keywords, proper nouns, numbers, dates, and technical terms**.
-       - Remove conversational filler, vague words, and redundancies.
-       - Make the query concise and focused for information retrieval.
-
-    6. **Mark as unclear if intent is missing:**
-       - This includes nonsense, gibberish, insults, or statements without an apparent question.
+    1. **RAG**: The query asks about specific documents, papers, or systems such as DeepAnalyze, AgentMem, SAM3, SAM 3, SAM3D, or DeepSeek-OCR, or about technical architecture/implementation details from research papers.
+    2. **WEBSEARCH**: The query asks for current events, the latest news, real-time information after January 2024, or general factual knowledge not covered by the specialized documents.
+    3. **GENERAL**: The query is a simple calculation, definition, reasoning task, or common-knowledge question that doesn't need external data.
+
+    **Examples:**
+    - "What is DeepAnalyze?" β†’ RAG
+    - "Explain SAM 3 architecture" β†’ RAG
+    - "Latest AI news in 2025" β†’ WEBSEARCH
+    - "What is 15 times 7?" β†’ GENERAL
+
+    Respond STRICTLY with only one word: RAG, WEBSEARCH, or GENERAL. No punctuation or extra text.
     """
-
 if __name__ == "__main__":
     pass
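Since the router trusts the model to answer with a bare label, a thin wrapper that normalizes and guards the reply is cheap insurance. The sketch below mirrors `router_node` in agent/nodes.py; the fallback to GENERAL is an illustration added here, not part of this commit.

```python
from langchain_core.messages import HumanMessage, SystemMessage

VALID_ROUTES = {"RAG", "WEBSEARCH", "GENERAL"}

def classify_query(llm, prompt_text: str, query: str) -> str:
    """Run the classifier prompt and normalize the one-word reply."""
    reply = llm.invoke([
        SystemMessage(content=prompt_text),
        HumanMessage(content=query),
    ])
    label = reply.content.strip().upper()
    # Defensive default so a chatty, non-compliant reply cannot crash routing.
    return label if label in VALID_ROUTES else "GENERAL"
```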
agent/state.py CHANGED
@@ -18,7 +18,7 @@ class AgentState(TypedDict):
     conversation_summary: str = ""
 
 
-
+# Implement later if needed; omitted for now
 class QueryAnalysis(BaseModel):
     """Structured output for query analysis"""
     is_clear: bool = Field(description="Indicates if the user's question is clear and answerable")
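If `QueryAnalysis` is revived, the deleted more_nodes.py shows how it was consumed: through `with_structured_output`. A minimal sketch under that assumption; the `questions` and `clarification_needed` fields are taken from that deleted code, not from this diff's context.

```python
from pydantic import BaseModel, Field

class QueryAnalysis(BaseModel):
    """Structured output for query analysis."""
    is_clear: bool = Field(description="Indicates if the user's question is clear and answerable")
    questions: list[str] = Field(default_factory=list, description="Rewritten, self-contained sub-queries")
    clarification_needed: str | None = Field(default=None, description="Follow-up question when intent is unclear")

# Usage as in the deleted more_nodes.py, assuming a LangChain chat model `llm`:
#   analyzer = llm.with_structured_output(QueryAnalysis)
#   result = analyzer.invoke("rewrite and assess this query ...")
#   if result.is_clear:
#       ...use result.questions...
```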
config.py CHANGED
@@ -4,7 +4,10 @@ configs = {
     "DATA_PATH": "./docs/markdowns",
     "PERSIST_PATH": "./chroma_data",
     "EMBEDDING_MODEL_NAME": "sentence-transformers/all-mpnet-base-v2",
-    "COLLECTION_NAME": "langchain_mpnet_collection"
+    "COLLECTION_NAME": "langchain_mpnet_collection",
+    "LLM_MODEL_NAME": "gemini-2.0-flash",
+    "TEMPERATURE": 0.2,
+    "MAX_TOKENS": 2048,
 }
 
 if __name__ == "__main__":
core/rag_agent.py CHANGED
@@ -1,16 +1,21 @@
 import uuid
 from langchain_google_genai import ChatGoogleGenerativeAI
-import config
+from config import configs
 from agent.tools import *
 from agent.graph import create_agent_graph
 
+from dotenv import load_dotenv
+
+load_dotenv()
+
 class RAGAgent:
     def __init__(self):
         self.thread_id = str(uuid.uuid4())
 
         self.llm = ChatGoogleGenerativeAI(
-            model=config.LLM_MODEL,
-            temperature=config.LLM_TEMPERATURE
+            model=configs["LLM_MODEL_NAME"],
+            temperature=configs["TEMPERATURE"],
+            max_tokens=configs["MAX_TOKENS"]
         )
 
         vectordb = intialize_chroma_vectorstore()
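The `load_dotenv()` call only helps if a `.env` file supplies the credential that `ChatGoogleGenerativeAI` reads from the environment. A minimal sketch, assuming the standard `GOOGLE_API_KEY` variable (the `.env` file itself is not part of this commit):

```python
# .env (project root, never committed):
#   GOOGLE_API_KEY=your-key-here

import os
from dotenv import load_dotenv

load_dotenv()  # populates os.environ from .env without overriding existing vars

assert os.getenv("GOOGLE_API_KEY"), "Set GOOGLE_API_KEY in .env or the shell"
```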
requirements.txt CHANGED
@@ -13,4 +13,5 @@ langchain-community
 langchain_text_splitters
 pymupdf-layout
 sentence_transformers
-gradio
+gradio
+python-dotenv
test_scripts.py ADDED
@@ -0,0 +1,349 @@
+"""
+Test script for RAG Agent logic.
+Tests the agent workflow, nodes, state management, and retrieval.
+"""
+
+import sys
+from pathlib import Path
+
+# Add project root to path
+sys.path.insert(0, str(Path(__file__).parent))
+
+from langchain_core.messages import HumanMessage, AIMessage
+from agent.state import AgentState
+from core.rag_agent import RAGAgent
+
+
+def print_separator(title: str):
+    """Print a visual separator."""
+    print("\n" + "=" * 70)
+    print(f"  {title}")
+    print("=" * 70 + "\n")
+
+
+def test_agent_initialization():
+    """Test that RAGAgent can be initialized properly."""
+    print_separator("TEST 1: Agent Initialization")
+
+    try:
+        agent = RAGAgent()
+        print("βœ“ RAGAgent initialized successfully")
+        print(f"  - Thread ID: {agent.thread_id}")
+        print(f"  - LLM Model: {agent.llm.model_name if hasattr(agent.llm, 'model_name') else 'initialized'}")
+        print(f"  - Graph: {type(agent.agent_graph).__name__}")
+        return agent
+    except Exception as e:
+        print(f"βœ— Failed to initialize RAGAgent: {e}")
+        import traceback
+        traceback.print_exc()
+        return None
+
+
+def test_simple_query(agent: RAGAgent):
+    """Test a simple query execution."""
+    print_separator("TEST 2: Simple Query")
+
+    if agent is None:
+        print("βœ— Skipping - agent not initialized")
+        return False
+
+    try:
+        query = "What is DeepAnalyze?"
+        print(f"Query: '{query}'")
+
+        initial_state = {
+            "messages": [HumanMessage(content=query)],
+        }
+
+        result = agent.agent_graph.invoke(
+            initial_state,
+            config=agent.get_config()
+        )
+
+        messages = result.get("messages", [])
+        ai_messages = [m for m in messages if isinstance(m, AIMessage)]
+
+        if ai_messages:
+            print("βœ“ Query executed successfully")
+            print(f"  Total messages: {len(messages)}")
+            print(f"  Response length: {len(ai_messages[-1].content)} chars")
+            print("\n  Response preview:")
+            print(f"  {ai_messages[-1].content[:300]}...")
+            return True
+        else:
+            print("βœ— No AI response generated")
+            return False
+
+    except Exception as e:
+        print(f"βœ— Query execution failed: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+def test_rag_query(agent: RAGAgent):
+    """Test a query that should use RAG (local documents)."""
+    print_separator("TEST 3: RAG Query")
+
+    if agent is None:
+        print("βœ— Skipping - agent not initialized")
+        return False
+
+    try:
+        query = "Explain the architecture of SAM 3"
+        print(f"Query: '{query}' (should use local documents)")
+
+        initial_state = {
+            "messages": [HumanMessage(content=query)],
+        }
+
+        result = agent.agent_graph.invoke(
+            initial_state,
+            config=agent.get_config()
+        )
+
+        messages = result.get("messages", [])
+        rag_method = result.get("rag_method", "UNKNOWN")
+        ai_messages = [m for m in messages if isinstance(m, AIMessage)]
+
+        print(f"  Routing decision: {rag_method}")
+
+        if ai_messages:
+            print("βœ“ RAG query executed")
+            print("  Response preview:")
+            print(f"  {ai_messages[-1].content[:300]}...")
+            return True
+        else:
+            print("βœ— No response generated")
+            return False
+
+    except Exception as e:
+        print(f"βœ— RAG query failed: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+def test_web_search_query(agent: RAGAgent):
+    """Test a query that should use web search."""
+    print_separator("TEST 4: Web Search Query")
+
+    if agent is None:
+        print("βœ— Skipping - agent not initialized")
+        return False
+
+    try:
+        query = "What's the latest news about AI in 2025?"
+        print(f"Query: '{query}' (should use web search)")
+
+        initial_state = {
+            "messages": [HumanMessage(content=query)],
+        }
+
+        result = agent.agent_graph.invoke(
+            initial_state,
+            config=agent.get_config()
+        )
+
+        messages = result.get("messages", [])
+        rag_method = result.get("rag_method", "UNKNOWN")
+        ai_messages = [m for m in messages if isinstance(m, AIMessage)]
+
+        print(f"  Routing decision: {rag_method}")
+
+        if ai_messages:
+            print("βœ“ Web search query executed")
+            print("  Response preview:")
+            print(f"  {ai_messages[-1].content[:300]}...")
+            return True
+        else:
+            print("βœ— No response generated")
+            return False
+
+    except Exception as e:
+        print(f"βœ— Web search query failed: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+def test_general_query(agent: RAGAgent):
+    """Test a general query that doesn't need RAG or web search."""
+    print_separator("TEST 5: General Query")
+
+    if agent is None:
+        print("βœ— Skipping - agent not initialized")
+        return False
+
+    try:
+        query = "What is 15 multiplied by 7?"
+        print(f"Query: '{query}' (should use general LLM)")
+
+        initial_state = {
+            "messages": [HumanMessage(content=query)],
+        }
+
+        result = agent.agent_graph.invoke(
+            initial_state,
+            config=agent.get_config()
+        )
+
+        messages = result.get("messages", [])
+        rag_method = result.get("rag_method", "UNKNOWN")
+        ai_messages = [m for m in messages if isinstance(m, AIMessage)]
+
+        print(f"  Routing decision: {rag_method}")
+
+        if ai_messages:
+            print("βœ“ General query executed")
+            print(f"  Response: {ai_messages[-1].content}")
+            return True
+        else:
+            print("βœ— No response generated")
+            return False
+
+    except Exception as e:
+        print(f"βœ— General query failed: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+def test_conversation_memory(agent: RAGAgent):
+    """Test multi-turn conversation with memory."""
+    print_separator("TEST 6: Conversation Memory")
+
+    if agent is None:
+        print("βœ— Skipping - agent not initialized")
+        return False
+
+    try:
+        # Reset the thread for a clean test
+        agent.reset_thread()
+
+        # First turn
+        print("Turn 1: 'What is DeepAnalyze?'")
+        state1 = {
+            "messages": [HumanMessage(content="What is DeepAnalyze?")],
+        }
+        result1 = agent.agent_graph.invoke(state1, config=agent.get_config())
+
+        ai_msg_1 = [m for m in result1["messages"] if isinstance(m, AIMessage)]
+        if not ai_msg_1:
+            print("βœ— No response in turn 1")
+            return False
+
+        print(f"βœ“ Turn 1 response: {ai_msg_1[-1].content[:100]}...")
+
+        # Second turn - follow-up question
+        print("\nTurn 2: 'What are its main features?' (requires context)")
+        state2 = {
+            "messages": [HumanMessage(content="What are its main features?")],
+        }
+        result2 = agent.agent_graph.invoke(state2, config=agent.get_config())
+
+        ai_msg_2 = [m for m in result2["messages"] if isinstance(m, AIMessage)]
+        if not ai_msg_2:
+            print("βœ— No response in turn 2")
+            return False
+
+        print(f"βœ“ Turn 2 response: {ai_msg_2[-1].content[:100]}...")
+
+        # Check if the response makes sense in context
+        response = ai_msg_2[-1].content.lower()
+        if "deepanalyze" in response or "feature" in response or "agent" in response:
+            print("βœ“ Conversation memory working - response uses context")
+            return True
+        else:
+            print("⚠ Response may not be using conversation context properly")
+            return True  # Still pass, as it generated a response
+
+    except Exception as e:
+        print(f"βœ— Conversation memory test failed: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+def test_thread_reset(agent: RAGAgent):
+    """Test thread reset functionality."""
+    print_separator("TEST 7: Thread Reset")
+
+    if agent is None:
+        print("βœ— Skipping - agent not initialized")
+        return False
+
+    try:
+        old_thread_id = agent.thread_id
+        print(f"Old thread ID: {old_thread_id}")
+
+        agent.reset_thread()
+
+        new_thread_id = agent.thread_id
+        print(f"New thread ID: {new_thread_id}")
+
+        if old_thread_id != new_thread_id:
+            print("βœ“ Thread reset successfully")
+            return True
+        else:
+            print("βœ— Thread ID unchanged after reset")
+            return False
+
+    except Exception as e:
+        print(f"βœ— Thread reset failed: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+def run_all_tests():
+    """Run all tests and provide a summary."""
+    print("\n" + "β–ˆ" * 70)
+    print("  RAG AGENT TEST SUITE")
+    print("β–ˆ" * 70)
+
+    # Initialize the agent once
+    agent = test_agent_initialization()
+
+    if agent is None:
+        print("\nβœ— Cannot proceed - agent initialization failed")
+        return False
+
+    tests = [
+        ("Simple Query", lambda: test_simple_query(agent)),
+        ("RAG Query", lambda: test_rag_query(agent)),
+        ("Web Search Query", lambda: test_web_search_query(agent)),
+        ("General Query", lambda: test_general_query(agent)),
+        ("Conversation Memory", lambda: test_conversation_memory(agent)),
+        ("Thread Reset", lambda: test_thread_reset(agent)),
+    ]
+
+    results = {}
+    for name, test_func in tests:
+        try:
+            results[name] = test_func()
+        except Exception as e:
+            print(f"\nβœ— Test '{name}' crashed: {e}")
+            import traceback
+            traceback.print_exc()
+            results[name] = False
+
+    # Print summary
+    print_separator("TEST SUMMARY")
+    passed = sum(results.values())
+    total = len(results)
+
+    for name, passed_test in results.items():
+        status = "βœ“ PASS" if passed_test else "βœ— FAIL"
+        print(f"{status}: {name}")
+
+    print(f"\n{'='*70}")
+    print(f"  TOTAL: {passed}/{total} tests passed ({passed/total*100:.1f}%)")
+    print(f"{'='*70}\n")
+
+    return passed == total
+
+
+if __name__ == "__main__":
+    success = run_all_tests()
+    sys.exit(0 if success else 1)
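The suite is plain print-based rather than pytest: run it with `python test_scripts.py`. Since `run_all_tests` returns an overall boolean and `sys.exit` turns it into an exit code, the script can gate a CI step; note that every test after initialization hits the live LLM, so the API key must be configured first.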
ui/gradio_components.py CHANGED
@@ -1,4 +1,5 @@
 import gradio as gr
+from langchain_core.messages import HumanMessage, AIMessage
 from core.rag_agent import RAGAgent
 
 # Initialize components
@@ -19,81 +20,143 @@ def chat_with_agent(message, history):
     try:
         agent = initialize_agent()
 
-        # Stream the agent's response
-        response_text = ""
-        for event in agent.agent_graph.stream(
-            {"messages": [("user", message)]},
-            agent.get_config(),
-            stream_mode="values"
-        ):
-            if "messages" in event and len(event["messages"]) > 0:
-                last_message = event["messages"][-1]
-                if hasattr(last_message, "content"):
-                    response_text = last_message.content
+        # Convert Gradio history format to LangChain messages
+        messages = []
+        for user_msg, assistant_msg in history:
+            messages.append(HumanMessage(content=user_msg))
+            if assistant_msg:
+                messages.append(AIMessage(content=assistant_msg))
 
-        if not response_text:
-            response_text = "I apologize, but I couldn't generate a response. Please try again."
+        # Add the current user message
+        messages.append(HumanMessage(content=message))
 
-        return response_text
-
+        # Create the initial state
+        initial_state = {
+            "messages": messages,
+        }
+
+        # Invoke the agent graph
+        result = agent.agent_graph.invoke(
+            initial_state,
+            config=agent.get_config()
+        )
+
+        # Extract the AI response
+        result_messages = result.get("messages", [])
+        ai_messages = [m for m in result_messages if isinstance(m, AIMessage)]
+
+        if ai_messages:
+            # Get the last AI message
+            response = ai_messages[-1].content
+
+            # Append routing info as metadata (optional)
+            rag_method = result.get("rag_method", "UNKNOWN")
+            response_with_metadata = f"{response}\n\n*[Source: {rag_method}]*"
+
+            # Return history in Gradio's format: [[user, bot], [user, bot], ...]
+            new_history = history + [[message, response_with_metadata]]
+            return new_history
+        else:
+            new_history = history + [[message, "⚠️ No response generated. Please try again."]]
+            return new_history
+
     except Exception as e:
-        return f"Error: {str(e)}"
+        error_msg = f"❌ Error: {str(e)}"
+        print(f"Chat error: {e}")
+        import traceback
+        traceback.print_exc()
+
+        new_history = history + [[message, error_msg]]
+        return new_history
 
 def reset_conversation():
     """Reset the conversation thread"""
     global rag_agent
     if rag_agent:
         rag_agent.reset_thread()
-    return None  # Clear chat history
+    return []  # Clear chat history
 
 def create_gradio_ui():
     """Create the complete Gradio interface"""
 
-    with gr.Blocks(title="RAG Agent with Agentic Memory", theme=gr.themes.Soft()) as demo:
+    with gr.Blocks(title="RAG Agent with Agentic Memory") as demo:
         gr.Markdown("""
         # πŸ€– RAG Agent with Agentic Memory
 
         Chat with an intelligent agent that uses:
-        - πŸ“š **Local Knowledge Base** (ChromaDB)
-        - πŸ” **Web Search** (Tavily)
-        - πŸ“– **Wikipedia**
-        - πŸŽ“ **ArXiv** (Academic Papers)
+        - πŸ“š **Local Knowledge Base** (ChromaDB) - Research papers on DeepAnalyze, AgentMem, SAM3, etc.
+        - πŸ” **Web Search** (Tavily) - Real-time information and current events
+        - πŸ“– **Wikipedia** - General knowledge
+        - πŸŽ“ **ArXiv** - Academic papers
        """)
 
-        gr.Markdown("### Chat with Your Documents")
-        gr.Markdown("Ask questions about your documents or any topic. The agent will search multiple sources.")
-
-        chatbot = gr.Chatbot(
-            label="Conversation",
-            height=500,
-            show_label=True,
-            avatar_images=(None, "πŸ€–")
-        )
-
         with gr.Row():
-            msg = gr.Textbox(
-                label="Your Message",
-                placeholder="Ask me anything about your documents or general knowledge...",
-                scale=4
-            )
-            submit_btn = gr.Button("Send", variant="primary", scale=1)
+            with gr.Column(scale=4):
+                gr.Markdown("### πŸ’¬ Chat Interface")
+
+                chatbot = gr.Chatbot(
+                    label="Conversation",
+                    height=500,
+                    show_label=False,
+                )
+
+                with gr.Row():
+                    msg = gr.Textbox(
+                        label="Your Message",
+                        placeholder="Ask me anything about your documents or general knowledge...",
+                        scale=5,
+                        show_label=False
+                    )
+                    submit_btn = gr.Button("Send πŸ“€", variant="primary", scale=1)
+
+                with gr.Row():
+                    clear_btn = gr.Button("πŸ”„ Reset Conversation", variant="secondary")
+
+            with gr.Column(scale=1):
+                gr.Markdown("### πŸ“Š Agent Status")
+                status_box = gr.Markdown("*Ready*")
+
+                gr.Markdown("### πŸ’‘ Example Queries")
+                gr.Markdown("""
+                **Local Documents (RAG):**
+                - What is DeepAnalyze?
+                - Explain SAM 3 architecture
+                - What is AgentMem?
+
+                **Web Search:**
+                - Latest AI news in 2025
+                - Current events in technology
+
+                **General:**
+                - What is 15 Γ— 7?
+                - Explain machine learning
+                """)
 
-        with gr.Row():
-            clear_chat_btn = gr.Button("πŸ”„ Reset Conversation")
-            gr.Markdown("*Note: Resetting clears the conversation history*")
+        # Event handlers
+        def submit_message(message, history):
+            """Handle message submission and clear the input box"""
+            if not message.strip():
+                return history, ""
+
+            # Get the response
+            new_history = chat_with_agent(message, history)
+
+            return new_history, ""
 
-        # Chat interface
-        chat_interface = gr.ChatInterface(
-            fn=chat_with_agent,
-            chatbot=chatbot,
-            textbox=msg,
-            submit_btn=submit_btn,
-            retry_btn=None,
-            undo_btn=None,
-            clear_btn=None
+        # Wire up events
+        msg.submit(
+            fn=submit_message,
+            inputs=[msg, chatbot],
+            outputs=[chatbot, msg]
         )
 
-        clear_chat_btn.click(
+        submit_btn.click(
+            fn=submit_message,
+            inputs=[msg, chatbot],
+            outputs=[chatbot, msg]
+        )
+
+        clear_btn.click(
            fn=reset_conversation,
            outputs=[chatbot]
        )
@@ -101,16 +164,27 @@ def create_gradio_ui():
        gr.Markdown("""
        ---
        ### πŸ”§ How it works:
-        1. **Ask questions** in the chat
+        1. **Type your question** in the text box
        2. The agent will:
-           - Analyze your query
-           - Search relevant sources (ChromaDB, Web, Wikipedia, ArXiv)
-           - Provide comprehensive answers with citations
-        3. Use **Reset Conversation** to start fresh
+           - 🧠 Analyze your query to determine the best source
+           - πŸ” Search relevant sources (Local docs, Web, Wikipedia)
+           - πŸ“ Generate a comprehensive answer
+           - πŸ’Ύ Remember conversation context for follow-up questions
+        3. Use **Reset Conversation** to start a new thread
+
+        ---
+        *Powered by LangGraph + LangChain + ChromaDB + Google Gemini*
        """)
 
    return demo
 
 if __name__ == "__main__":
    demo = create_gradio_ui()
-    demo.launch(share=False, server_name="127.0.0.1", server_port=7860)
+    print("πŸš€ Starting Gradio interface...")
+    print("πŸ“ Running on: http://127.0.0.1:7860")
+    demo.launch(
+        share=False,
+        server_name="127.0.0.1",
+        server_port=7860,
+        show_error=True
+    )
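The core of the Gradio fix is the translation between the chatbot's pair-based history and LangChain message objects. A self-contained sketch of that conversion, mirroring `chat_with_agent` above; note that this assumes the classic tuple history layout, while newer Gradio releases also offer an OpenAI-style `type="messages"` format:

```python
from langchain_core.messages import AIMessage, HumanMessage

def history_to_messages(history, current):
    """Convert Gradio's [[user, bot], ...] pairs into LangChain messages."""
    messages = []
    for user_msg, assistant_msg in history:
        messages.append(HumanMessage(content=user_msg))
        if assistant_msg:  # the bot slot can be None while a reply is pending
            messages.append(AIMessage(content=assistant_msg))
    messages.append(HumanMessage(content=current))  # the current turn goes last
    return messages

print(history_to_messages([["hi", "hello!"]], "What is SAM 3?"))
```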