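"""GraphGen pipeline: read documents, build a knowledge graph from them, and
generate QA pairs from the graph."""
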
import asyncio
import os
import time
from dataclasses import dataclass, field
from typing import Dict, Optional, cast

import gradio as gr
from graphgen.bases.base_storage import StorageNameSpace
from graphgen.bases.datatypes import Chunk
from graphgen.models import (
JsonKVStorage,
JsonListStorage,
NetworkXStorage,
OpenAIClient,
Tokenizer,
)
from graphgen.operators import (
build_kg,
chunk_documents,
generate_qas,
judge_statement,
partition_kg,
quiz,
read_files,
search_all,
)
from graphgen.utils import async_to_sync_method, compute_content_hash, logger
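
# Repository root (the parent of this package); the default cache dir lives here.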
sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))


@dataclass
class GraphGen:
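    """End-to-end pipeline object: ingests documents, builds the knowledge
    graph, optionally enriches it via search and quiz/judge, and generates
    QA pairs."""
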
    # default_factory so each instance gets its own timestamp-based id
    # (a plain `int(time.time())` default would be evaluated once at class
    # definition time and shared by every instance)
    unique_id: int = field(default_factory=lambda: int(time.time()))
    working_dir: str = os.path.join(sys_path, "cache")

    # llm
    tokenizer_instance: Optional[Tokenizer] = None
    synthesizer_llm_client: Optional[OpenAIClient] = None
    trainee_llm_client: Optional[OpenAIClient] = None

    # webui
    progress_bar: Optional[gr.Progress] = None

def __post_init__(self):
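        """Create default tokenizer/LLM clients from environment variables when
        not supplied, and initialise the storage backends under working_dir."""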
self.tokenizer_instance: Tokenizer = self.tokenizer_instance or Tokenizer(
model_name=os.getenv("TOKENIZER_MODEL")
)
self.synthesizer_llm_client: OpenAIClient = (
self.synthesizer_llm_client
or OpenAIClient(
model_name=os.getenv("SYNTHESIZER_MODEL"),
api_key=os.getenv("SYNTHESIZER_API_KEY"),
base_url=os.getenv("SYNTHESIZER_BASE_URL"),
tokenizer=self.tokenizer_instance,
)
)
self.trainee_llm_client: OpenAIClient = self.trainee_llm_client or OpenAIClient(
model_name=os.getenv("TRAINEE_MODEL"),
api_key=os.getenv("TRAINEE_API_KEY"),
base_url=os.getenv("TRAINEE_BASE_URL"),
tokenizer=self.tokenizer_instance,
)
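
        # storage backends (all rooted at working_dir)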
self.full_docs_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="full_docs"
)
self.text_chunks_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="text_chunks"
)
self.graph_storage: NetworkXStorage = NetworkXStorage(
self.working_dir, namespace="graph"
)
self.search_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="search"
)
self.rephrase_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="rephrase"
)
self.qa_storage: JsonListStorage = JsonListStorage(
os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
namespace="qa",
)

    @async_to_sync_method
async def insert(self, read_config: Dict, split_config: Dict):
"""
insert chunks into the graph
"""
# Step 1: Read files
data = read_files(read_config["input_file"])
if len(data) == 0:
logger.warning("No data to process")
return
# TODO: configurable whether to use coreference resolution
# Step 2: Split chunks and filter existing ones
assert isinstance(data, list) and isinstance(data[0], dict)
new_docs = {
compute_content_hash(doc["content"], prefix="doc-"): {
"content": doc["content"]
}
for doc in data
}
_add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if len(new_docs) == 0:
logger.warning("All docs are already in the storage")
return
logger.info("[New Docs] inserting %d docs", len(new_docs))
inserting_chunks = await chunk_documents(
new_docs,
split_config["chunk_size"],
split_config["chunk_overlap"],
self.tokenizer_instance,
self.progress_bar,
)
_add_chunk_keys = await self.text_chunks_storage.filter_keys(
list(inserting_chunks.keys())
)
inserting_chunks = {
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
}
if len(inserting_chunks) == 0:
logger.warning("All chunks are already in the storage")
return
logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
await self.full_docs_storage.upsert(new_docs)
await self.text_chunks_storage.upsert(inserting_chunks)
# Step 3: Extract entities and relations from chunks
logger.info("[Entity and Relation Extraction]...")
_add_entities_and_relations = await build_kg(
llm_client=self.synthesizer_llm_client,
kg_instance=self.graph_storage,
chunks=[
Chunk(id=k, content=v["content"]) for k, v in inserting_chunks.items()
],
progress_bar=self.progress_bar,
)
if not _add_entities_and_relations:
logger.warning("No entities or relations extracted")
return
await self._insert_done()
return _add_entities_and_relations

    async def _insert_done(self):
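        """Flush every storage backend that participated in an insert."""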
tasks = []
for storage_instance in [
self.full_docs_storage,
self.text_chunks_storage,
self.graph_storage,
self.search_storage,
]:
if storage_instance is None:
continue
tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
await asyncio.gather(*tasks)

    @async_to_sync_method
async def search(self, search_config: Dict):
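        """Run the configured external searches for graph entities and store
        the results."""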
logger.info(
"Search is %s", "enabled" if search_config["enabled"] else "disabled"
)
if search_config["enabled"]:
logger.info("[Search] %s ...", ", ".join(search_config["search_types"]))
all_nodes = await self.graph_storage.get_all_nodes()
all_nodes_names = [node[0] for node in all_nodes]
new_search_entities = await self.full_docs_storage.filter_keys(
all_nodes_names
)
logger.info(
"[Search] Found %d entities to search", len(new_search_entities)
)
_add_search_data = await search_all(
search_types=search_config["search_types"],
search_entities=new_search_entities,
)
if _add_search_data:
await self.search_storage.upsert(_add_search_data)
logger.info("[Search] %d entities searched", len(_add_search_data))
                # Format search results into chunk-style records for re-insertion
                search_results = [
                    {"content": value}
                    for search_data in _add_search_data.values()
                    for value in search_data.values()
                ]
                # TODO: fix insert after search; insert() requires read/split
                # configs and should consume search_results
                await self.insert()

@async_to_sync_method
async def quiz_and_judge(self, quiz_and_judge_config: Dict):
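        """Generate quiz statements from the graph with the synthesizer model,
        then have the trainee model judge them and update the relations."""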
if quiz_and_judge_config is None or not quiz_and_judge_config.get(
"enabled", False
):
logger.warning("Quiz and Judge is not used in this pipeline.")
return
max_samples = quiz_and_judge_config["quiz_samples"]
await quiz(
self.synthesizer_llm_client,
self.graph_storage,
self.rephrase_storage,
max_samples,
)
# TODO: assert trainee_llm_client is valid before judge
re_judge = quiz_and_judge_config["re_judge"]
_update_relations = await judge_statement(
self.trainee_llm_client,
self.graph_storage,
self.rephrase_storage,
re_judge,
)
await self.rephrase_storage.index_done_callback()
await _update_relations.index_done_callback()

    @async_to_sync_method
async def generate(self, partition_config: Dict, generate_config: Dict):
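        """Partition the knowledge graph into batches and generate QA pairs
        from each batch."""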
# Step 1: partition the graph
batches = await partition_kg(
self.graph_storage, self.tokenizer_instance, partition_config
)
# Step 2: generate QA pairs
results = await generate_qas(
self.synthesizer_llm_client,
batches,
generate_config,
progress_bar=self.progress_bar,
)
if not results:
logger.warning("No QA pairs generated")
return
# Step 3: store the generated QA pairs
await self.qa_storage.upsert(results)
await self.qa_storage.index_done_callback()

    @async_to_sync_method
async def clear(self):
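        """Drop all cached documents, chunks, search results, graph data,
        rephrased statements, and generated QA pairs."""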
await self.full_docs_storage.drop()
await self.text_chunks_storage.drop()
await self.search_storage.drop()
await self.graph_storage.clear()
await self.rephrase_storage.drop()
await self.qa_storage.drop()
logger.info("All caches are cleared")