import asyncio
import os
import time
from dataclasses import dataclass
from typing import Dict, cast

import gradio as gr

from graphgen.bases.base_storage import StorageNameSpace
from graphgen.bases.datatypes import Chunk
from graphgen.models import (
    JsonKVStorage,
    JsonListStorage,
    NetworkXStorage,
    OpenAIClient,
    Tokenizer,
)
from graphgen.operators import (
    build_kg,
    chunk_documents,
    generate_qas,
    judge_statement,
    partition_kg,
    quiz,
    read_files,
    search_all,
)
from graphgen.utils import async_to_sync_method, compute_content_hash, logger

sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
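
# NOTE: the LLM clients are configured entirely from environment variables in
# __post_init__ below: TOKENIZER_MODEL, SYNTHESIZER_MODEL, SYNTHESIZER_API_KEY,
# SYNTHESIZER_BASE_URL, TRAINEE_MODEL, TRAINEE_API_KEY, and TRAINEE_BASE_URL.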

@dataclass
class GraphGen:
    unique_id: int = int(time.time())
    working_dir: str = os.path.join(sys_path, "cache")

    # llm
    tokenizer_instance: Tokenizer = None
    synthesizer_llm_client: OpenAIClient = None
    trainee_llm_client: OpenAIClient = None

    # webui
    progress_bar: gr.Progress = None

    def __post_init__(self):
        self.tokenizer_instance: Tokenizer = self.tokenizer_instance or Tokenizer(
            model_name=os.getenv("TOKENIZER_MODEL")
        )
        self.synthesizer_llm_client: OpenAIClient = (
            self.synthesizer_llm_client
            or OpenAIClient(
                model_name=os.getenv("SYNTHESIZER_MODEL"),
                api_key=os.getenv("SYNTHESIZER_API_KEY"),
                base_url=os.getenv("SYNTHESIZER_BASE_URL"),
                tokenizer=self.tokenizer_instance,
            )
        )
        self.trainee_llm_client: OpenAIClient = self.trainee_llm_client or OpenAIClient(
            model_name=os.getenv("TRAINEE_MODEL"),
            api_key=os.getenv("TRAINEE_API_KEY"),
            base_url=os.getenv("TRAINEE_BASE_URL"),
            tokenizer=self.tokenizer_instance,
        )

        self.full_docs_storage: JsonKVStorage = JsonKVStorage(
            self.working_dir, namespace="full_docs"
        )
        self.text_chunks_storage: JsonKVStorage = JsonKVStorage(
            self.working_dir, namespace="text_chunks"
        )
        self.graph_storage: NetworkXStorage = NetworkXStorage(
            self.working_dir, namespace="graph"
        )
        self.search_storage: JsonKVStorage = JsonKVStorage(
            self.working_dir, namespace="search"
        )
        self.rephrase_storage: JsonKVStorage = JsonKVStorage(
            self.working_dir, namespace="rephrase"
        )
        self.qa_storage: JsonListStorage = JsonListStorage(
            os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
            namespace="qa",
        )

    @async_to_sync_method
    async def insert(self, read_config: Dict, split_config: Dict):
        """Read source files, split them into chunks, and insert them into the graph."""
        # Step 1: Read files
        data = read_files(read_config["input_file"])
        if len(data) == 0:
            logger.warning("No data to process")
            return

        # TODO: configurable whether to use coreference resolution

        # Step 2: Split chunks and filter existing ones
        assert isinstance(data, list) and isinstance(data[0], dict)
        new_docs = {
            compute_content_hash(doc["content"], prefix="doc-"): {
                "content": doc["content"]
            }
            for doc in data
        }
        _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
        new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
        if len(new_docs) == 0:
            logger.warning("All docs are already in the storage")
            return
        logger.info("[New Docs] inserting %d docs", len(new_docs))

        inserting_chunks = await chunk_documents(
            new_docs,
            split_config["chunk_size"],
            split_config["chunk_overlap"],
            self.tokenizer_instance,
            self.progress_bar,
        )
        _add_chunk_keys = await self.text_chunks_storage.filter_keys(
            list(inserting_chunks.keys())
        )
        inserting_chunks = {
            k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
        }
        if len(inserting_chunks) == 0:
            logger.warning("All chunks are already in the storage")
            return
        logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))

        await self.full_docs_storage.upsert(new_docs)
        await self.text_chunks_storage.upsert(inserting_chunks)

        # Step 3: Extract entities and relations from chunks
        logger.info("[Entity and Relation Extraction]...")
        _add_entities_and_relations = await build_kg(
            llm_client=self.synthesizer_llm_client,
            kg_instance=self.graph_storage,
            chunks=[
                Chunk(id=k, content=v["content"]) for k, v in inserting_chunks.items()
            ],
            progress_bar=self.progress_bar,
        )
        if not _add_entities_and_relations:
            logger.warning("No entities or relations extracted")
            return

        await self._insert_done()
        return _add_entities_and_relations

    async def _insert_done(self):
        # Flush every storage backend that participated in the insert pass
        tasks = []
        for storage_instance in [
            self.full_docs_storage,
            self.text_chunks_storage,
            self.graph_storage,
            self.search_storage,
        ]:
            if storage_instance is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
        await asyncio.gather(*tasks)

    @async_to_sync_method
    async def search(self, search_config: Dict):
        logger.info(
            "Search is %s", "enabled" if search_config["enabled"] else "disabled"
        )
        if search_config["enabled"]:
            logger.info("[Search] %s ...", ", ".join(search_config["search_types"]))
            all_nodes = await self.graph_storage.get_all_nodes()
            all_nodes_names = [node[0] for node in all_nodes]
            new_search_entities = await self.full_docs_storage.filter_keys(
                all_nodes_names
            )
            logger.info(
                "[Search] Found %d entities to search", len(new_search_entities)
            )
            _add_search_data = await search_all(
                search_types=search_config["search_types"],
                search_entities=new_search_entities,
            )
            if _add_search_data:
                await self.search_storage.upsert(_add_search_data)
                logger.info("[Search] %d entities searched", len(_add_search_data))

                # Format search results for inserting
                search_results = []
                for _, search_data in _add_search_data.items():
                    search_results.extend(
                        [
                            {"content": search_data[key]}
                            for key in list(search_data.keys())
                        ]
                    )
                # TODO: fix insert after search. self.insert() requires read and
                # split configs, so re-inserting search_results is disabled until
                # that is wired up:
                # await self.insert()

    @async_to_sync_method
    async def quiz_and_judge(self, quiz_and_judge_config: Dict):
        if quiz_and_judge_config is None or not quiz_and_judge_config.get(
            "enabled", False
        ):
            logger.warning("Quiz and Judge is not used in this pipeline.")
            return

        max_samples = quiz_and_judge_config["quiz_samples"]
        await quiz(
            self.synthesizer_llm_client,
            self.graph_storage,
            self.rephrase_storage,
            max_samples,
        )

        # TODO: assert trainee_llm_client is valid before judge
        re_judge = quiz_and_judge_config["re_judge"]
        _update_relations = await judge_statement(
            self.trainee_llm_client,
            self.graph_storage,
            self.rephrase_storage,
            re_judge,
        )

        await self.rephrase_storage.index_done_callback()
        await _update_relations.index_done_callback()

    @async_to_sync_method
    async def generate(self, partition_config: Dict, generate_config: Dict):
        # Step 1: partition the graph
        batches = await partition_kg(
            self.graph_storage, self.tokenizer_instance, partition_config
        )

        # Step 2: generate QA pairs
        results = await generate_qas(
            self.synthesizer_llm_client,
            batches,
            generate_config,
            progress_bar=self.progress_bar,
        )
        if not results:
            logger.warning("No QA pairs generated")
            return

        # Step 3: store the generated QA pairs
        await self.qa_storage.upsert(results)
        await self.qa_storage.index_done_callback()

    @async_to_sync_method
    async def clear(self):
        # Drop every storage namespace and reset the graph
        await self.full_docs_storage.drop()
        await self.text_chunks_storage.drop()
        await self.search_storage.drop()
        await self.graph_storage.clear()
        await self.rephrase_storage.drop()
        await self.qa_storage.drop()
        logger.info("All caches are cleared")