import math
import asyncio

from tqdm.asyncio import tqdm as tqdm_async

from graphgen.models import NetworkXStorage, OpenAIModel, JsonKVStorage
from graphgen.utils import logger, yes_no_loss_entropy
from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT


async def judge_statement(  # pylint: disable=too-many-statements
    trainee_llm_client: OpenAIModel,
    graph_storage: NetworkXStorage,
    rephrase_storage: JsonKVStorage,
    re_judge: bool = False,
    max_concurrent: int = 1000,
) -> NetworkXStorage:
| """ | |
| Get all edges and nodes and judge them | |
| :param trainee_llm_client: judge the statements to get comprehension loss | |
| :param graph_storage: graph storage instance | |
| :param rephrase_storage: rephrase storage instance | |
| :param re_judge: re-judge the relations | |
| :param max_concurrent: max concurrent | |
| :return: | |
| """ | |
    semaphore = asyncio.Semaphore(max_concurrent)
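
    # Judge one relation: ask the trainee model about each rephrased form of the
    # edge description and turn its yes/no answers into a comprehension loss.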
    async def _judge_single_relation(
        edge: tuple,
    ):
        async with semaphore:
            source_id = edge[0]
            target_id = edge[1]
            edge_data = edge[2]
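
            # Reuse the existing loss unless re-judging is explicitly requested.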
            if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
                logger.info(
                    "Edge %s -> %s already judged, loss: %s, skip",
                    source_id, target_id, edge_data["loss"],
                )
                return source_id, target_id, edge_data

            description = edge_data["description"]

            try:
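                # The rephrase storage maps a description to (rephrased statement, ground truth) pairs.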
                descriptions = await rephrase_storage.get_by_id(description)
                assert descriptions is not None

                judgements = []
                gts = [gt for _, gt in descriptions]
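                # Judge each rephrased statement and keep the top candidates of the
                # first generated token, expected to be a yes/no answer.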
                for description, gt in descriptions:
                    judgement = await trainee_llm_client.generate_topk_per_token(
                        STATEMENT_JUDGEMENT_PROMPT['TEMPLATE'].format(statement=description)
                    )
                    judgements.append(judgement[0].top_candidates)
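
                # Aggregate the per-statement judgements into a single comprehension loss.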
                loss = yes_no_loss_entropy(judgements, gts)

                logger.info(
                    "Edge %s -> %s description: %s loss: %s",
                    source_id, target_id, description, loss,
                )

                edge_data["loss"] = loss
            except Exception as e:  # pylint: disable=broad-except
                logger.error("Error in judging relation %s -> %s: %s", source_id, target_id, e)
                logger.info("Use default loss 0.1")
| edge_data["loss"] = -math.log(0.1) | |
| await graph_storage.update_edge(source_id, target_id, edge_data) | |
| return source_id, target_id, edge_data | |
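
    # Launch one judgement task per edge; the semaphore bounds concurrency and tqdm reports progress.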
    edges = await graph_storage.get_all_edges()

    results = []
    for result in tqdm_async(
        asyncio.as_completed([_judge_single_relation(edge) for edge in edges]),
        total=len(edges),
        desc="Judging relations",
    ):
        results.append(await result)
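
    # Entity judgement mirrors the relation judgement above, but scores node descriptions.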
    async def _judge_single_entity(
        node: tuple,
    ):
        async with semaphore:
            node_id = node[0]
            node_data = node[1]

            if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
                logger.info("Node %s already judged, loss: %s, skip", node_id, node_data["loss"])
                return node_id, node_data

            description = node_data["description"]

            try:
                descriptions = await rephrase_storage.get_by_id(description)
                assert descriptions is not None

                judgements = []
                gts = [gt for _, gt in descriptions]
                for description, gt in descriptions:
                    judgement = await trainee_llm_client.generate_topk_per_token(
                        STATEMENT_JUDGEMENT_PROMPT['TEMPLATE'].format(statement=description)
                    )
                    judgements.append(judgement[0].top_candidates)

                loss = yes_no_loss_entropy(judgements, gts)

                logger.info("Node %s description: %s loss: %s", node_id, description, loss)

                node_data["loss"] = loss
            except Exception as e:  # pylint: disable=broad-except
                logger.error("Error in judging entity %s: %s", node_id, e)
                logger.info("Use default loss 0.1")
                node_data["loss"] = -math.log(0.1)

            await graph_storage.update_node(node_id, node_data)
            return node_id, node_data
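
    # Launch one judgement task per node.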
    nodes = await graph_storage.get_all_nodes()

    results = []
    for result in tqdm_async(
        asyncio.as_completed([_judge_single_entity(node) for node in nodes]),
        total=len(nodes),
        desc="Judging entities",
    ):
        results.append(await result)

    return graph_storage


async def skip_judge_statement(
    graph_storage: NetworkXStorage,
    max_concurrent: int = 1000
):
    """
    Skip the actual judgement and assign the default loss to every edge and node.

    :param graph_storage: graph storage instance
    :param max_concurrent: maximum number of concurrent update tasks
    :return: the updated graph storage
    """
    semaphore = asyncio.Semaphore(max_concurrent)

    async def _skip_single_relation(
        edge: tuple,
    ):
        async with semaphore:
            source_id = edge[0]
            target_id = edge[1]
            edge_data = edge[2]

            if "loss" in edge_data and edge_data["loss"] is not None:
                logger.info(
                    "Edge %s -> %s already judged, loss: %s, skip",
                    source_id, target_id, edge_data["loss"],
                )
                return source_id, target_id, edge_data
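
            # Assign the same default loss that judge_statement uses as its failure fallback.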
| edge_data["loss"] = -math.log(0.1) | |
| await graph_storage.update_edge(source_id, target_id, edge_data) | |
| return source_id, target_id, edge_data | |
| edges = await graph_storage.get_all_edges() | |
| results = [] | |
| for result in tqdm_async( | |
| asyncio.as_completed([_skip_single_relation(edge) for edge in edges]), | |
| total=len(edges), | |
| desc="Skipping judgement of relations" | |
| ): | |
| results.append(await result) | |

    async def _skip_single_entity(
        node: tuple,
    ):
        async with semaphore:
            node_id = node[0]
            node_data = node[1]

            if "loss" in node_data and node_data["loss"] is not None:
                logger.info("Node %s already judged, loss: %s, skip", node_id, node_data["loss"])
                return node_id, node_data

            node_data["loss"] = -math.log(0.1)
            await graph_storage.update_node(node_id, node_data)
            return node_id, node_data

    nodes = await graph_storage.get_all_nodes()

    results = []
    for result in tqdm_async(
        asyncio.as_completed([_skip_single_entity(node) for node in nodes]),
        total=len(nodes),
        desc="Skipping judgement of entities",
    ):
        results.append(await result)

    return graph_storage
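

# Minimal usage sketch (illustrative only): it assumes the trainee LLM client and the
# storages have already been constructed by the surrounding GraphGen pipeline; the
# variable names below are placeholders, not the library's actual constructors.
#
#     graph_storage = await judge_statement(
#         trainee_llm_client=trainee_llm_client,
#         graph_storage=graph_storage,
#         rephrase_storage=rephrase_storage,
#         re_judge=False,
#         max_concurrent=100,
#     )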