mtyrrell committed
Commit ea45e0c · 1 Parent(s): 9374fdf
Files changed (2)
  1. app/main.py +4 -380
  2. app/nodes.py +305 -1
app/main.py CHANGED
@@ -1,31 +1,26 @@
 #CHATFED_ORCHESTRATOR
 import gradio as gr
-from fastapi import FastAPI, UploadFile, File, Form, Request
+from fastapi import FastAPI, UploadFile, File, Form
 from fastapi.responses import StreamingResponse
 from langserve import add_routes
 from langgraph.graph import StateGraph, START, END
-from typing import Optional, Dict, Any, List
+from typing import Optional
 from typing_extensions import TypedDict
-from gradio_client import Client, file
 import uvicorn
 import os
 from datetime import datetime
 import logging
 from contextlib import asynccontextmanager
-# import threading
 from langchain_core.runnables import RunnableLambda
-# import tempfile
-# import mimetypes
 import asyncio
 from typing import Generator
 import json
 import httpx
-# import ast
 from functools import partial
 
 from utils import getconfig, convert_context_to_list
-from nodes import detect_file_type_node, ingest_node, geojson_direct_result_node, retrieve_node
-from models import GraphState, ChatFedInput, ChatFedOutput, ChatUIInput
+from nodes import detect_file_type_node, ingest_node, geojson_direct_result_node, retrieve_node, generate_node_streaming, route_workflow, process_query_streaming
+from models import GraphState, ChatUIInput
 
 config = getconfig("params.cfg")
 RETRIEVER = config.get("retriever", "RETRIEVER", fallback="https://giz-chatfed-retriever.hf.space")
@@ -39,250 +34,6 @@ logger = logging.getLogger(__name__)
 
 
 
-# MAIN STREAMING GENERATOR
-async def generate_node_streaming(state: GraphState) -> Generator[GraphState, None, None]:
-    """Streaming version that calls generator's FastAPI endpoint"""
-    start_time = datetime.now()
-    logger.info(f"Generation (streaming): {state['query'][:50]}...")
-
-    try:
-        # Get MAX_CONTEXT_CHARS at the beginning so it's available throughout the function
-        MAX_CONTEXT_CHARS = int(config.get("general", "MAX_CONTEXT_CHARS"))
-
-        # Combine retriever context with ingestor context
-        retrieved_context = state.get("context", "")
-        ingestor_context = state.get("ingestor_context", "")
-
-        # Convert contexts to list format expected by generator
-        context_list = []
-
-        if ingestor_context:
-            # Add ingestor context
-            context_list.append({
-                "answer": ingestor_context,
-                "answer_metadata": {
-                    "filename": state.get("filename", "Uploaded Document"),
-                    "page": "Unknown",
-                    "year": "Unknown",
-                    "source": "Ingestor"
-                }
-            })
-
-        if retrieved_context:
-            # Convert retrieved context to list and add
-            retrieved_list = convert_context_to_list(retrieved_context)
-            context_list.extend(retrieved_list)
-
-        # Prepare the request payload
-        payload = {
-            "query": state["query"],
-            "context": context_list
-        }
-
-        # Determine generator URL - handle both Hugging Face and direct URLs
-        generator_url = GENERATOR
-
-        if not generator_url.startswith('http'):
-            # Allows for easy specification of space in config (converts to URL)
-            # Replace '/' with '-' for Hugging Face space URLs
-            # Force the replacement to ensure it works
-            space_name = generator_url.replace('/', '-').replace('_', '-')
-            generator_url = f"https://{space_name}.hf.space"
-
-
-        # Try FastAPI endpoint first, fallback to Gradio if needed
-        fastapi_success = False
-
-        try:
-            # Make streaming request to generator's FastAPI endpoint
-            async with httpx.AsyncClient(timeout=300.0, verify=False) as client:
-
-                async with client.stream(
-                    "POST",
-                    f"{generator_url}/generate/stream",
-                    json=payload,
-                    headers={"Content-Type": "application/json"}
-                ) as response:
-                    if response.status_code != 200:
-                        error_text = await response.aread()
-                        raise Exception(f"FastAPI endpoint returned status {response.status_code}")
-
-                    current_text = ""
-                    sources = None
-                    event_type = None
-
-                    async for line in response.aiter_lines():
-                        if not line.strip():
-                            continue
-
-                        # Parse SSE format
-                        if line.startswith("event: "):
-                            event_type = line[7:].strip()
-                            continue
-                        elif line.startswith("data: "):
-                            data_content = line[6:].strip()
-
-                            if event_type == "data":
-                                # Text chunk
-                                try:
-                                    chunk = json.loads(data_content)
-                                    if isinstance(chunk, str):
-                                        current_text += chunk
-
-                                        metadata = state.get("metadata", {})
-                                        metadata.update({
-                                            "generation_duration": (datetime.now() - start_time).total_seconds(),
-                                            "result_length": len(current_text),
-                                            "generation_success": True,
-                                            "streaming": True,
-                                            "generator_type": "fastapi"
-                                        })
-
-                                        yield {
-                                            "result": chunk,  # Send only the new chunk
-                                            "metadata": metadata
-                                        }
-                                except json.JSONDecodeError:
-                                    # Handle plain text chunks
-                                    current_text += data_content
-
-                                    metadata = state.get("metadata", {})
-                                    metadata.update({
-                                        "generation_duration": (datetime.now() - start_time).total_seconds(),
-                                        "result_length": len(current_text),
-                                        "generation_success": True,
-                                        "streaming": True,
-                                        "generator_type": "fastapi"
-                                    })
-
-                                    yield {
-                                        "result": data_content,
-                                        "metadata": metadata
-                                    }
-
-                            elif event_type == "sources":
-                                # Sources data
-                                try:
-                                    sources_data = json.loads(data_content)
-                                    sources = sources_data.get("sources", [])
-
-                                    # Update state with sources
-                                    metadata = state.get("metadata", {})
-                                    metadata.update({
-                                        "sources_received": True,
-                                        "sources_count": len(sources)
-                                    })
-
-                                    yield {
-                                        "sources": sources,
-                                        "metadata": metadata
-                                    }
-                                except json.JSONDecodeError:
-                                    logger.warning(f"Failed to parse sources data: {data_content}")
-
-                            elif event_type == "end":
-                                # Stream ended
-                                logger.info("Generator stream ended")
-                                fastapi_success = True
-                                break
-
-                            elif event_type == "error":
-                                # Error occurred
-                                try:
-                                    error_data = json.loads(data_content)
-                                    raise Exception(error_data.get("error", "Unknown error"))
-                                except json.JSONDecodeError:
-                                    raise Exception(data_content)
-
-        # GRADIO FALLBACK
-        except Exception as fastapi_error:
-            logger.warning(f"FastAPI endpoint failed: {fastapi_error}")
-            logger.info("Falling back to Gradio client")
-
-            # # Fallback to Gradio client
-            # try:
-            #     from gradio_client import Client
-
-            #     # Convert context back to string for Gradio
-            #     combined_context = ""
-            #     if ingestor_context and retrieved_context:
-            #         # Limit context size to prevent token overflow
-            #         ingestor_truncated = ingestor_context[:MAX_CONTEXT_CHARS//2] if len(ingestor_context) > MAX_CONTEXT_CHARS//2 else ingestor_context
-            #         retrieved_truncated = retrieved_context[:MAX_CONTEXT_CHARS//2] if len(retrieved_context) > MAX_CONTEXT_CHARS//2 else retrieved_context
-            #         combined_context = f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_truncated}\n\n=== RETRIEVED CONTEXT ===\n{retrieved_truncated}"
-            #     elif ingestor_context:
-            #         ingestor_truncated = ingestor_context[:MAX_CONTEXT_CHARS] if len(ingestor_context) > MAX_CONTEXT_CHARS else ingestor_context
-            #         combined_context = f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_truncated}"
-            #     elif retrieved_context:
-            #         combined_context = retrieved_context[:MAX_CONTEXT_CHARS] if len(retrieved_context) > MAX_CONTEXT_CHARS else retrieved_context
-
-            #     logger.info(f"Using Gradio client for generator at: {generator_url}")
-            #     client = Client(generator_url)
-
-            #     # Use streaming prediction
-            #     job = client.submit(
-            #         query=state["query"],
-            #         context=combined_context,
-            #         api_name="/generate"
-            #     )
-
-            #     # Track previous result to send only deltas
-            #     previous_result = ""
-
-            #     # Stream the results - each result is likely the full accumulated response
-            #     for result in job:
-            #         if result is not None:
-            #             current_result = result
-
-            #             # Calculate the delta (new content only)
-            #             if len(current_result) > len(previous_result):
-            #                 delta = current_result[len(previous_result):]
-            #                 previous_result = current_result
-
-            #                 # Yield only the new content
-            #                 metadata = state.get("metadata", {})
-            #                 metadata.update({
-            #                     "generation_duration": (datetime.now() - start_time).total_seconds(),
-            #                     "result_length": len(current_result),
-            #                     "generation_success": True,
-            #                     "streaming": True,
-            #                     "generator_type": "gradio_fallback"
-            #                 })
-
-            #                 yield {
-            #                     "result": delta,  # Send only the delta, not full result
-            #                     "metadata": metadata
-            #                 }
-
-            #     fastapi_success = True  # Mark as successful since Gradio worked
-
-            # except Exception as gradio_error:
-            #     logger.error(f"Both FastAPI and Gradio failed. FastAPI: {fastapi_error}, Gradio: {gradio_error}")
-            #     raise Exception(f"Both generation methods failed. FastAPI: {fastapi_error}, Gradio: {gradio_error}")

-        # if not fastapi_success:
-        #     raise Exception("Generation failed - no successful response received")
-
-    except Exception as e:
-        duration = (datetime.now() - start_time).total_seconds()
-        logger.error(f"Streaming generation failed: {str(e)}")
-
-        metadata = state.get("metadata", {})
-        metadata.update({
-            "generation_duration": duration,
-            "generation_success": False,
-            "generation_error": str(e),
-            "streaming": True
-        })
-        yield {"result": f"Error: {str(e)}", "metadata": metadata}
-
-# Conditional routing function
-def route_workflow(state: GraphState) -> str:
-    """Route to appropriate workflow based on file type"""
-    workflow_type = state.get("workflow_type", "standard")
-    return workflow_type
-
-
 #----------------------------------------
 # CORE WORKFLOW GRAPH
 #----------------------------------------
@@ -318,134 +69,7 @@ workflow.add_edge("geojson_direct", END)
 compiled_graph = workflow.compile()
 
 
-async def process_query_streaming(query: str, file_upload, reports_filter: str = "", sources_filter: str = "",
-                                  subtype_filter: str = "", year_filter: str = "",
-                                  output_format: str = "structured"):
-    """
-    Unified streaming function that yields partial results
-
-    Args:
-        output_format: "structured" for dict format, "gradio" for plain text format
-    """
-    file_content = None
-    filename = None
-
-    if file_upload is not None:
-        try:
-            with open(file_upload.name, 'rb') as f:
-                file_content = f.read()
-            filename = os.path.basename(file_upload.name)
-            logger.info(f"File uploaded: {filename}, size: {len(file_content)} bytes")
-        except Exception as e:
-            logger.error(f"Error reading uploaded file: {str(e)}")
-            if output_format == "structured":
-                yield {"type": "error", "content": f"Error reading file: {str(e)}"}
-            else:
-                yield f"Error reading file: {str(e)}"
-            return
-
-    start_time = datetime.now()
-    session_id = f"gradio_{start_time.strftime('%Y%m%d_%H%M%S')}"
-
-    try:
-        # Process ingestion first (non-streaming)
-        initial_state = {
-            "query": query,
-            "context": "",
-            "ingestor_context": "",
-            "result": "",
-            "sources": [],
-            "reports_filter": reports_filter or "",
-            "sources_filter": sources_filter or "",
-            "subtype_filter": subtype_filter or "",
-            "year_filter": year_filter or "",
-            "file_content": file_content,
-            "filename": filename,
-            "file_type": "unknown",
-            "workflow_type": "standard",
-            "metadata": {
-                "session_id": session_id,
-                "start_time": start_time.isoformat(),
-                "has_file_attachment": file_content is not None
-            }
-        }
-
-        # Detect file type - merge the returned state with initial state
-        state_after_detect = {**initial_state, **detect_file_type_node(initial_state)}
-
-        # Ingest if file provided - merge the returned state
-        state_after_ingest = {**state_after_detect, **ingest_node(state_after_detect)}
-
-        # Route workflow
-        workflow_type = route_workflow(state_after_ingest)
-
-        if workflow_type == "geojson_direct":
-            # For GeoJSON, return direct result
-            final_state = geojson_direct_result_node(state_after_ingest)
-            if output_format == "structured":
-                yield {"type": "data", "content": final_state["result"]}
-                yield {"type": "end", "content": ""}
-            else:
-                yield final_state["result"]
-        else:
-            # For standard workflow, retrieve first - merge the returned state
-            state_after_retrieve = {**state_after_ingest, **retrieve_node(state_after_ingest)}
-
-            # Initialize variables for both output formats
-            sources_collected = None
-            accumulated_response = "" if output_format == "gradio" else None
-
-            # Then stream generation
-            async for partial_state in generate_node_streaming(state_after_retrieve):
-                if "result" in partial_state:
-                    if output_format == "structured":
-                        yield {"type": "data", "content": partial_state["result"]}
-                    else:
-                        # Accumulate the content and yield the full accumulated response
-                        accumulated_response += partial_state["result"]
-                        yield accumulated_response
-
-                # Collect sources for later
-                if "sources" in partial_state:
-                    sources_collected = partial_state["sources"]
-
-            # Handle sources based on output format
-            if sources_collected:
-                if output_format == "structured":
-                    yield {"type": "sources", "content": sources_collected}
-                else:
-                    # Append sources to accumulated response
-                    sources_text = "\n\n**Sources:**\n"
-                    for i, source in enumerate(sources_collected, 1):
-                        if isinstance(source, dict):
-                            title = source.get('title', 'Unknown')
-                            link = source.get('link', '#')
-                            sources_text += f"{i}. [{title}]({link})\n"
-                        else:
-                            sources_text += f"{i}. {source}\n"
-
-                    accumulated_response += sources_text
-                    yield accumulated_response
-
-        if output_format == "structured":
-            yield {"type": "end", "content": ""}
-
-    except Exception as e:
-        logger.error(f"Streaming pipeline failed: {str(e)}")
-        if output_format == "structured":
-            yield {"type": "error", "content": f"Error: {str(e)}"}
-        else:
-            yield f"Error: {str(e)}"
 
-# # Convenience wrapper for Gradio compatibility
-# async def process_query_gradio_streaming(query: str, file_upload, reports_filter: str = "", sources_filter: str = "",
-#                                          subtype_filter: str = "", year_filter: str = ""):
-#     """Streaming version for Gradio UI - wrapper around unified function"""
-#     async for result in process_query_streaming(
-#         query, file_upload, reports_filter, sources_filter,
-#         subtype_filter, year_filter, output_format="gradio"
-#     ):
-#         yield result
 
 
 async def chatui_adapter(data):
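
For orientation, a minimal consumer sketch (not part of the commit; the handler name is hypothetical, everything else comes from the imports added above): with output_format="gradio", process_query_streaming yields the full accumulated response text on each step, which suits a Gradio streaming callback.

# Hypothetical sketch, assuming main.py's new imports from nodes
from nodes import process_query_streaming

async def gradio_stream_handler(query, file_upload):
    # Each yield is the accumulated answer so far; a "**Sources:**"
    # section is appended once generation finishes.
    async for accumulated in process_query_streaming(
        query, file_upload, output_format="gradio"
    ):
        yield accumulated
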
app/nodes.py CHANGED
@@ -7,6 +7,10 @@ from gradio_client import Client, file
 import logging
 from utils import getconfig
 import dotenv
+from typing_extensions import TypedDict
+import httpx
+import json
+from typing import Generator
 
 dotenv.load_dotenv()
 
@@ -180,4 +184,304 @@ def retrieve_node(state: GraphState) -> GraphState:
             "retrieval_success": False,
             "retrieval_error": str(e)
         })
-        return {"context": "", "metadata": metadata}
+        return {"context": "", "metadata": metadata}
+
+
+
+
+# MAIN STREAMING GENERATOR
+async def generate_node_streaming(state: GraphState) -> Generator[GraphState, None, None]:
+    """Streaming version that calls generator's FastAPI endpoint"""
+    start_time = datetime.now()
+    logger.info(f"Generation (streaming): {state['query'][:50]}...")
+
+    try:
+        # Get MAX_CONTEXT_CHARS at the beginning so it's available throughout the function
+        MAX_CONTEXT_CHARS = int(config.get("general", "MAX_CONTEXT_CHARS"))
+
+        # Combine retriever context with ingestor context
+        retrieved_context = state.get("context", "")
+        ingestor_context = state.get("ingestor_context", "")
+
+        # Convert contexts to list format expected by generator
+        context_list = []
+
+        if ingestor_context:
+            # Add ingestor context
+            context_list.append({
+                "answer": ingestor_context,
+                "answer_metadata": {
+                    "filename": state.get("filename", "Uploaded Document"),
+                    "page": "Unknown",
+                    "year": "Unknown",
+                    "source": "Ingestor"
+                }
+            })
+
+        if retrieved_context:
+            # Convert retrieved context to list and add
+            retrieved_list = convert_context_to_list(retrieved_context)
+            context_list.extend(retrieved_list)
+
+        # Prepare the request payload
+        payload = {
+            "query": state["query"],
+            "context": context_list
+        }
+
+        # Determine generator URL - handle both Hugging Face and direct URLs
+        generator_url = GENERATOR
+
+        if not generator_url.startswith('http'):
+            # Allows for easy specification of space in config (converts to URL)
+            # Replace '/' with '-' for Hugging Face space URLs
+            # Force the replacement to ensure it works
+            space_name = generator_url.replace('/', '-').replace('_', '-')
+            generator_url = f"https://{space_name}.hf.space"
+
+
+        # Try FastAPI endpoint first, fallback to Gradio if needed
+        fastapi_success = False
+
+        try:
+            # Make streaming request to generator's FastAPI endpoint
+            async with httpx.AsyncClient(timeout=300.0, verify=False) as client:
+
+                async with client.stream(
+                    "POST",
+                    f"{generator_url}/generate/stream",
+                    json=payload,
+                    headers={"Content-Type": "application/json"}
+                ) as response:
+                    if response.status_code != 200:
+                        error_text = await response.aread()
+                        raise Exception(f"FastAPI endpoint returned status {response.status_code}")
+
+                    current_text = ""
+                    sources = None
+                    event_type = None
+
+                    async for line in response.aiter_lines():
+                        if not line.strip():
+                            continue
+
+                        # Parse SSE format
+                        if line.startswith("event: "):
+                            event_type = line[7:].strip()
+                            continue
+                        elif line.startswith("data: "):
+                            data_content = line[6:].strip()
+
+                            if event_type == "data":
+                                # Text chunk
+                                try:
+                                    chunk = json.loads(data_content)
+                                    if isinstance(chunk, str):
+                                        current_text += chunk
+
+                                        metadata = state.get("metadata", {})
+                                        metadata.update({
+                                            "generation_duration": (datetime.now() - start_time).total_seconds(),
+                                            "result_length": len(current_text),
+                                            "generation_success": True,
+                                            "streaming": True,
+                                            "generator_type": "fastapi"
+                                        })
+
+                                        yield {
+                                            "result": chunk,  # Send only the new chunk
+                                            "metadata": metadata
+                                        }
+                                except json.JSONDecodeError:
+                                    # Handle plain text chunks
+                                    current_text += data_content
+
+                                    metadata = state.get("metadata", {})
+                                    metadata.update({
+                                        "generation_duration": (datetime.now() - start_time).total_seconds(),
+                                        "result_length": len(current_text),
+                                        "generation_success": True,
+                                        "streaming": True,
+                                        "generator_type": "fastapi"
+                                    })
+
+                                    yield {
+                                        "result": data_content,
+                                        "metadata": metadata
+                                    }
+
+                            elif event_type == "sources":
+                                # Sources data
+                                try:
+                                    sources_data = json.loads(data_content)
+                                    sources = sources_data.get("sources", [])
+
+                                    # Update state with sources
+                                    metadata = state.get("metadata", {})
+                                    metadata.update({
+                                        "sources_received": True,
+                                        "sources_count": len(sources)
+                                    })
+
+                                    yield {
+                                        "sources": sources,
+                                        "metadata": metadata
+                                    }
+                                except json.JSONDecodeError:
+                                    logger.warning(f"Failed to parse sources data: {data_content}")
+
+                            elif event_type == "end":
+                                # Stream ended
+                                logger.info("Generator stream ended")
+                                fastapi_success = True
+                                break
+
+                            elif event_type == "error":
+                                # Error occurred
+                                try:
+                                    error_data = json.loads(data_content)
+                                    raise Exception(error_data.get("error", "Unknown error"))
+                                except json.JSONDecodeError:
+                                    raise Exception(data_content)
+
+
+    except Exception as e:
+        duration = (datetime.now() - start_time).total_seconds()
+        logger.error(f"Streaming generation failed: {str(e)}")
+
+        metadata = state.get("metadata", {})
+        metadata.update({
+            "generation_duration": duration,
+            "generation_success": False,
+            "generation_error": str(e),
+            "streaming": True
+        })
+        yield {"result": f"Error: {str(e)}", "metadata": metadata}
+
+# Conditional routing function
+def route_workflow(state: GraphState) -> str:
+    """Route to appropriate workflow based on file type"""
+    workflow_type = state.get("workflow_type", "standard")
+    return workflow_type
+
+
+
+
+async def process_query_streaming(query: str, file_upload, reports_filter: str = "", sources_filter: str = "",
+                                  subtype_filter: str = "", year_filter: str = "",
+                                  output_format: str = "structured"):
+    """
+    Unified streaming function that yields partial results
+
+    Args:
+        output_format: "structured" for dict format, "gradio" for plain text format
+    """
+    file_content = None
+    filename = None
+
+    if file_upload is not None:
+        try:
+            with open(file_upload.name, 'rb') as f:
+                file_content = f.read()
+            filename = os.path.basename(file_upload.name)
+            logger.info(f"File uploaded: {filename}, size: {len(file_content)} bytes")
+        except Exception as e:
+            logger.error(f"Error reading uploaded file: {str(e)}")
+            if output_format == "structured":
+                yield {"type": "error", "content": f"Error reading file: {str(e)}"}
+            else:
+                yield f"Error reading file: {str(e)}"
+            return
+
+    start_time = datetime.now()
+    session_id = f"gradio_{start_time.strftime('%Y%m%d_%H%M%S')}"
+
+    try:
+        # Process ingestion first (non-streaming)
+        initial_state = {
+            "query": query,
+            "context": "",
+            "ingestor_context": "",
+            "result": "",
+            "sources": [],
+            "reports_filter": reports_filter or "",
+            "sources_filter": sources_filter or "",
+            "subtype_filter": subtype_filter or "",
+            "year_filter": year_filter or "",
+            "file_content": file_content,
+            "filename": filename,
+            "file_type": "unknown",
+            "workflow_type": "standard",
+            "metadata": {
+                "session_id": session_id,
+                "start_time": start_time.isoformat(),
+                "has_file_attachment": file_content is not None
+            }
+        }
+
+        # Detect file type - merge the returned state with initial state
+        state_after_detect = {**initial_state, **detect_file_type_node(initial_state)}
+
+        # Ingest if file provided - merge the returned state
+        state_after_ingest = {**state_after_detect, **ingest_node(state_after_detect)}
+
+        # Route workflow
+        workflow_type = route_workflow(state_after_ingest)
+
+        if workflow_type == "geojson_direct":
+            # For GeoJSON, return direct result
+            final_state = geojson_direct_result_node(state_after_ingest)
+            if output_format == "structured":
+                yield {"type": "data", "content": final_state["result"]}
+                yield {"type": "end", "content": ""}
+            else:
+                yield final_state["result"]
+        else:
+            # For standard workflow, retrieve first - merge the returned state
+            state_after_retrieve = {**state_after_ingest, **retrieve_node(state_after_ingest)}
+
+            # Initialize variables for both output formats
+            sources_collected = None
+            accumulated_response = "" if output_format == "gradio" else None
+
+            # Then stream generation
+            async for partial_state in generate_node_streaming(state_after_retrieve):
+                if "result" in partial_state:
+                    if output_format == "structured":
+                        yield {"type": "data", "content": partial_state["result"]}
+                    else:
+                        # Accumulate the content and yield the full accumulated response
+                        accumulated_response += partial_state["result"]
+                        yield accumulated_response
+
+                # Collect sources for later
+                if "sources" in partial_state:
+                    sources_collected = partial_state["sources"]
+
+            # Handle sources based on output format
+            if sources_collected:
+                if output_format == "structured":
+                    yield {"type": "sources", "content": sources_collected}
+                else:
+                    # Append sources to accumulated response
+                    sources_text = "\n\n**Sources:**\n"
+                    for i, source in enumerate(sources_collected, 1):
+                        if isinstance(source, dict):
+                            title = source.get('title', 'Unknown')
+                            link = source.get('link', '#')
+                            sources_text += f"{i}. [{title}]({link})\n"
+                        else:
+                            sources_text += f"{i}. {source}\n"
+
+                    accumulated_response += sources_text
+                    yield accumulated_response
+
+        if output_format == "structured":
+            yield {"type": "end", "content": ""}
+
+    except Exception as e:
+        logger.error(f"Streaming pipeline failed: {str(e)}")
+        if output_format == "structured":
+            yield {"type": "error", "content": f"Error: {str(e)}"}
+        else:
+            yield f"Error: {str(e)}"