github-actions[bot] committed on
Commit
37f0321
·
1 Parent(s): 1f2cf0b

Auto-sync from demo at Fri Oct 10 04:18:04 UTC 2025

Browse files
graphgen/bases/base_kg_builder.py CHANGED
@@ -10,7 +10,6 @@ from graphgen.bases.datatypes import Chunk
10
 
11
  @dataclass
12
  class BaseKGBuilder(ABC):
13
- kg_instance: BaseGraphStorage
14
  llm_client: BaseLLMClient
15
 
16
  _nodes: Dict[str, List[dict]] = field(default_factory=lambda: defaultdict(list))
@@ -18,14 +17,6 @@ class BaseKGBuilder(ABC):
18
  default_factory=lambda: defaultdict(list)
19
  )
20
 
21
- def build(self, chunks: List[Chunk]) -> None:
22
- pass
23
-
24
- @abstractmethod
25
- async def extract_all(self, chunks: List[Chunk]) -> None:
26
- """Extract nodes and edges from all chunks."""
27
- raise NotImplementedError
28
-
29
  @abstractmethod
30
  async def extract(
31
  self, chunk: Chunk
@@ -35,7 +26,18 @@ class BaseKGBuilder(ABC):
35
 
36
  @abstractmethod
37
  async def merge_nodes(
38
- self, nodes_data: Dict[str, List[dict]], kg_instance: BaseGraphStorage, llm
 
 
39
  ) -> None:
40
  """Merge extracted nodes into the knowledge graph."""
41
  raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
10
 
11
  @dataclass
12
  class BaseKGBuilder(ABC):
 
13
  llm_client: BaseLLMClient
14
 
15
  _nodes: Dict[str, List[dict]] = field(default_factory=lambda: defaultdict(list))
 
17
  default_factory=lambda: defaultdict(list)
18
  )
19
 
 
 
 
 
 
 
 
 
20
  @abstractmethod
21
  async def extract(
22
  self, chunk: Chunk
 
26
 
27
  @abstractmethod
28
  async def merge_nodes(
29
+ self,
30
+ node_data: tuple[str, List[dict]],
31
+ kg_instance: BaseGraphStorage,
32
  ) -> None:
33
  """Merge extracted nodes into the knowledge graph."""
34
  raise NotImplementedError
35
+
36
+ @abstractmethod
37
+ async def merge_edges(
38
+ self,
39
+ edges_data: tuple[Tuple[str, str], List[dict]],
40
+ kg_instance: BaseGraphStorage,
41
+ ) -> None:
42
+ """Merge extracted edges into the knowledge graph."""
43
+ raise NotImplementedError
graphgen/bases/base_llm_client.py CHANGED
@@ -57,12 +57,6 @@ class BaseLLMClient(abc.ABC):
57
  """Generate probabilities for each token in the input."""
58
  raise NotImplementedError
59
 
60
- def count_tokens(self, text: str) -> int:
61
- """Count the number of tokens in the text."""
62
- if self.tokenizer is None:
63
- raise ValueError("Tokenizer is not set. Please provide a tokenizer to use count_tokens.")
64
- return len(self.tokenizer.encode(text))
65
-
66
  @staticmethod
67
  def filter_think_tags(text: str, think_tag: str = "think") -> str:
68
  """
 
57
  """Generate probabilities for each token in the input."""
58
  raise NotImplementedError
59
 
 
 
 
 
 
 
60
  @staticmethod
61
  def filter_think_tags(text: str, think_tag: str = "think") -> str:
62
  """
graphgen/graphgen.py CHANGED
@@ -16,8 +16,8 @@ from graphgen.models import (
16
  Tokenizer,
17
  )
18
  from graphgen.operators import (
 
19
  chunk_documents,
20
- extract_kg,
21
  generate_cot,
22
  judge_statement,
23
  quiz,
@@ -146,10 +146,9 @@ class GraphGen:
146
 
147
  # Step 3: Extract entities and relations from chunks
148
  logger.info("[Entity and Relation Extraction]...")
149
- _add_entities_and_relations = await extract_kg(
150
  llm_client=self.synthesizer_llm_client,
151
  kg_instance=self.graph_storage,
152
- tokenizer_instance=self.tokenizer_instance,
153
  chunks=[
154
  Chunk(id=k, content=v["content"]) for k, v in inserting_chunks.items()
155
  ],
 
16
  Tokenizer,
17
  )
18
  from graphgen.operators import (
19
+ build_kg,
20
  chunk_documents,
 
21
  generate_cot,
22
  judge_statement,
23
  quiz,
 
146
 
147
  # Step 3: Extract entities and relations from chunks
148
  logger.info("[Entity and Relation Extraction]...")
149
+ _add_entities_and_relations = await build_kg(
150
  llm_client=self.synthesizer_llm_client,
151
  kg_instance=self.graph_storage,
 
152
  chunks=[
153
  Chunk(id=k, content=v["content"]) for k, v in inserting_chunks.items()
154
  ],
graphgen/models/__init__.py CHANGED
@@ -3,6 +3,7 @@ from .evaluate.length_evaluator import LengthEvaluator
3
  from .evaluate.mtld_evaluator import MTLDEvaluator
4
  from .evaluate.reward_evaluator import RewardEvaluator
5
  from .evaluate.uni_evaluator import UniEvaluator
 
6
  from .llm.openai_client import OpenAIClient
7
  from .llm.topk_token_model import TopkTokenModel
8
  from .reader import CsvReader, JsonlReader, JsonReader, TxtReader
 
3
  from .evaluate.mtld_evaluator import MTLDEvaluator
4
  from .evaluate.reward_evaluator import RewardEvaluator
5
  from .evaluate.uni_evaluator import UniEvaluator
6
+ from .kg_builder.light_rag_kg_builder import LightRAGKGBuilder
7
  from .llm.openai_client import OpenAIClient
8
  from .llm.topk_token_model import TopkTokenModel
9
  from .reader import CsvReader, JsonlReader, JsonReader, TxtReader
graphgen/models/kg_builder/NetworkXKGBuilder.py DELETED
@@ -1,18 +0,0 @@
1
- from dataclasses import dataclass
2
-
3
- from graphgen.bases import BaseKGBuilder
4
-
5
-
6
- @dataclass
7
- class NetworkXKGBuilder(BaseKGBuilder):
8
- def build(self, chunks):
9
- pass
10
-
11
- async def extract_all(self, chunks):
12
- pass
13
-
14
- async def extract(self, chunk):
15
- pass
16
-
17
- async def merge_nodes(self, nodes_data, kg_instance, llm):
18
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
graphgen/models/kg_builder/light_rag_kg_builder.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from collections import Counter, defaultdict
3
+ from dataclasses import dataclass
4
+ from typing import Dict, List, Tuple
5
+
6
+ from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk
7
+ from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
8
+ from graphgen.utils import (
9
+ detect_if_chinese,
10
+ detect_main_language,
11
+ handle_single_entity_extraction,
12
+ handle_single_relationship_extraction,
13
+ logger,
14
+ pack_history_conversations,
15
+ split_string_by_multi_markers,
16
+ )
17
+
18
+
19
+ @dataclass
20
+ class LightRAGKGBuilder(BaseKGBuilder):
21
+ llm_client: BaseLLMClient = None
22
+ max_loop: int = 3
23
+
24
+ async def extract(
25
+ self, chunk: Chunk
26
+ ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
27
+ """
28
+ Extract entities and relationships from a single chunk using the LLM client.
29
+ :param chunk
30
+ :return: (nodes_data, edges_data)
31
+ """
32
+ chunk_id = chunk.id
33
+ content = chunk.content
34
+
35
+ # step 1: language_detection
36
+ language = "Chinese" if detect_if_chinese(content) else "English"
37
+ KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
38
+
39
+ hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
40
+ **KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
41
+ )
42
+
43
+ # step 2: initial glean
44
+ final_result = await self.llm_client.generate_answer(hint_prompt)
45
+ logger.debug("First extraction result: %s", final_result)
46
+
47
+ # step3: iterative refinement
48
+ history = pack_history_conversations(hint_prompt, final_result)
49
+ for loop_idx in range(self.max_loop):
50
+ if_loop_result = await self.llm_client.generate_answer(
51
+ text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
52
+ )
53
+ if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
54
+ if if_loop_result != "yes":
55
+ break
56
+
57
+ glean_result = await self.llm_client.generate_answer(
58
+ text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
59
+ )
60
+ logger.debug("Loop %s glean: %s", loop_idx + 1, glean_result)
61
+
62
+ history += pack_history_conversations(
63
+ KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
64
+ )
65
+ final_result += glean_result
66
+
67
+ # step 4: parse the final result
68
+ records = split_string_by_multi_markers(
69
+ final_result,
70
+ [
71
+ KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
72
+ KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
73
+ ],
74
+ )
75
+
76
+ nodes = defaultdict(list)
77
+ edges = defaultdict(list)
78
+
79
+ for record in records:
80
+ match = re.search(r"\((.*)\)", record)
81
+ if not match:
82
+ continue
83
+ inner = match.group(1)
84
+
85
+ attributes = split_string_by_multi_markers(
86
+ inner, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
87
+ )
88
+
89
+ entity = await handle_single_entity_extraction(attributes, chunk_id)
90
+ if entity is not None:
91
+ nodes[entity["entity_name"]].append(entity)
92
+ continue
93
+
94
+ relation = await handle_single_relationship_extraction(attributes, chunk_id)
95
+ if relation is not None:
96
+ key = (relation["src_id"], relation["tgt_id"])
97
+ edges[key].append(relation)
98
+
99
+ return dict(nodes), dict(edges)
100
+
101
+ async def merge_nodes(
102
+ self,
103
+ node_data: tuple[str, List[dict]],
104
+ kg_instance: BaseGraphStorage,
105
+ ) -> None:
106
+ entity_name, node_data = node_data
107
+ entity_types = []
108
+ source_ids = []
109
+ descriptions = []
110
+
111
+ node = await kg_instance.get_node(entity_name)
112
+ if node is not None:
113
+ entity_types.append(node["entity_type"])
114
+ source_ids.extend(
115
+ split_string_by_multi_markers(node["source_id"], ["<SEP>"])
116
+ )
117
+ descriptions.append(node["description"])
118
+
119
+ # take the most frequent entity_type
120
+ entity_type = sorted(
121
+ Counter([dp["entity_type"] for dp in node_data] + entity_types).items(),
122
+ key=lambda x: x[1],
123
+ reverse=True,
124
+ )[0][0]
125
+
126
+ description = "<SEP>".join(
127
+ sorted(set([dp["description"] for dp in node_data] + descriptions))
128
+ )
129
+ description = await self._handle_kg_summary(entity_name, description)
130
+
131
+ source_id = "<SEP>".join(
132
+ set([dp["source_id"] for dp in node_data] + source_ids)
133
+ )
134
+
135
+ node_data = {
136
+ "entity_type": entity_type,
137
+ "description": description,
138
+ "source_id": source_id,
139
+ }
140
+ await kg_instance.upsert_node(entity_name, node_data=node_data)
141
+
142
+ async def merge_edges(
143
+ self,
144
+ edges_data: tuple[Tuple[str, str], List[dict]],
145
+ kg_instance: BaseGraphStorage,
146
+ ) -> None:
147
+ (src_id, tgt_id), edge_data = edges_data
148
+
149
+ source_ids = []
150
+ descriptions = []
151
+
152
+ edge = await kg_instance.get_edge(src_id, tgt_id)
153
+ if edge is not None:
154
+ source_ids.extend(
155
+ split_string_by_multi_markers(edge["source_id"], ["<SEP>"])
156
+ )
157
+ descriptions.append(edge["description"])
158
+
159
+ description = "<SEP>".join(
160
+ sorted(set([dp["description"] for dp in edge_data] + descriptions))
161
+ )
162
+ source_id = "<SEP>".join(
163
+ set([dp["source_id"] for dp in edge_data] + source_ids)
164
+ )
165
+
166
+ for insert_id in [src_id, tgt_id]:
167
+ if not await kg_instance.has_node(insert_id):
168
+ await kg_instance.upsert_node(
169
+ insert_id,
170
+ node_data={
171
+ "source_id": source_id,
172
+ "description": description,
173
+ "entity_type": "UNKNOWN",
174
+ },
175
+ )
176
+
177
+ description = await self._handle_kg_summary(
178
+ f"({src_id}, {tgt_id})", description
179
+ )
180
+
181
+ await kg_instance.upsert_edge(
182
+ src_id,
183
+ tgt_id,
184
+ edge_data={"source_id": source_id, "description": description},
185
+ )
186
+
187
+ async def _handle_kg_summary(
188
+ self,
189
+ entity_or_relation_name: str,
190
+ description: str,
191
+ max_summary_tokens: int = 200,
192
+ ) -> str:
193
+ """
194
+ Handle knowledge graph summary
195
+
196
+ :param entity_or_relation_name
197
+ :param description
198
+ :param max_summary_tokens
199
+ :return summary
200
+ """
201
+
202
+ tokenizer_instance = self.llm_client.tokenizer
203
+ language = detect_main_language(description)
204
+ if language == "en":
205
+ language = "English"
206
+ else:
207
+ language = "Chinese"
208
+ KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
209
+
210
+ tokens = tokenizer_instance.encode(description)
211
+ if len(tokens) < max_summary_tokens:
212
+ return description
213
+
214
+ use_description = tokenizer_instance.decode(tokens[:max_summary_tokens])
215
+ prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
216
+ entity_name=entity_or_relation_name,
217
+ description_list=use_description.split("<SEP>"),
218
+ **KG_SUMMARIZATION_PROMPT["FORMAT"],
219
+ )
220
+ new_description = await self.llm_client.generate_answer(prompt)
221
+ logger.info(
222
+ "Entity or relation %s summary: %s",
223
+ entity_or_relation_name,
224
+ new_description,
225
+ )
226
+ return new_description
graphgen/operators/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- from graphgen.operators.build_kg.extract_kg import extract_kg
2
  from graphgen.operators.generate.generate_cot import generate_cot
3
  from graphgen.operators.search.search_all import search_all
4
 
 
1
+ from graphgen.operators.build_kg.build_kg import build_kg
2
  from graphgen.operators.generate.generate_cot import generate_cot
3
  from graphgen.operators.search.search_all import search_all
4
 
graphgen/operators/build_kg/build_kg.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from typing import List
3
+
4
+ import gradio as gr
5
+
6
+ from graphgen.bases.base_storage import BaseGraphStorage
7
+ from graphgen.bases.datatypes import Chunk
8
+ from graphgen.models import LightRAGKGBuilder, OpenAIClient
9
+ from graphgen.utils import run_concurrent
10
+
11
+
12
+ async def build_kg(
13
+ llm_client: OpenAIClient,
14
+ kg_instance: BaseGraphStorage,
15
+ chunks: List[Chunk],
16
+ progress_bar: gr.Progress = None,
17
+ ):
18
+ """
19
+ :param llm_client: Synthesizer LLM model to extract entities and relationships
20
+ :param kg_instance
21
+ :param chunks
22
+ :param progress_bar: Gradio progress bar to show the progress of the extraction
23
+ :return:
24
+ """
25
+
26
+ kg_builder = LightRAGKGBuilder(llm_client=llm_client, max_loop=3)
27
+
28
+ results = await run_concurrent(
29
+ kg_builder.extract,
30
+ chunks,
31
+ desc="[2/4]Extracting entities and relationships from chunks",
32
+ unit="chunk",
33
+ progress_bar=progress_bar,
34
+ )
35
+
36
+ nodes = defaultdict(list)
37
+ edges = defaultdict(list)
38
+ for n, e in results:
39
+ for k, v in n.items():
40
+ nodes[k].extend(v)
41
+ for k, v in e.items():
42
+ edges[tuple(sorted(k))].extend(v)
43
+
44
+ await run_concurrent(
45
+ lambda kv: kg_builder.merge_nodes(kv, kg_instance=kg_instance),
46
+ list(nodes.items()),
47
+ desc="Inserting entities into storage",
48
+ )
49
+
50
+ await run_concurrent(
51
+ lambda kv: kg_builder.merge_edges(kv, kg_instance=kg_instance),
52
+ list(edges.items()),
53
+ desc="Inserting relationships into storage",
54
+ )
55
+
56
+ return kg_instance
graphgen/operators/build_kg/extract_kg.py DELETED
@@ -1,127 +0,0 @@
1
- import re
2
- from collections import defaultdict
3
- from typing import List
4
-
5
- import gradio as gr
6
-
7
- from graphgen.bases.base_storage import BaseGraphStorage
8
- from graphgen.bases.datatypes import Chunk
9
- from graphgen.models import OpenAIClient, Tokenizer
10
- from graphgen.operators.build_kg.merge_kg import merge_edges, merge_nodes
11
- from graphgen.templates import KG_EXTRACTION_PROMPT
12
- from graphgen.utils import (
13
- detect_if_chinese,
14
- handle_single_entity_extraction,
15
- handle_single_relationship_extraction,
16
- logger,
17
- pack_history_conversations,
18
- run_concurrent,
19
- split_string_by_multi_markers,
20
- )
21
-
22
-
23
- # pylint: disable=too-many-statements
24
- async def extract_kg(
25
- llm_client: OpenAIClient,
26
- kg_instance: BaseGraphStorage,
27
- tokenizer_instance: Tokenizer,
28
- chunks: List[Chunk],
29
- progress_bar: gr.Progress = None,
30
- ):
31
- """
32
- :param llm_client: Synthesizer LLM model to extract entities and relationships
33
- :param kg_instance
34
- :param tokenizer_instance
35
- :param chunks
36
- :param progress_bar: Gradio progress bar to show the progress of the extraction
37
- :return:
38
- """
39
-
40
- async def _process_single_content(chunk: Chunk, max_loop: int = 3):
41
- chunk_id = chunk.id
42
- content = chunk.content
43
- if detect_if_chinese(content):
44
- language = "Chinese"
45
- else:
46
- language = "English"
47
- KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
48
-
49
- hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
50
- **KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
51
- )
52
-
53
- final_result = await llm_client.generate_answer(hint_prompt)
54
- logger.info("First result: %s", final_result)
55
-
56
- history = pack_history_conversations(hint_prompt, final_result)
57
- for loop_index in range(max_loop):
58
- if_loop_result = await llm_client.generate_answer(
59
- text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
60
- )
61
- if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
62
- if if_loop_result != "yes":
63
- break
64
-
65
- glean_result = await llm_client.generate_answer(
66
- text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
67
- )
68
- logger.info("Loop %s glean: %s", loop_index, glean_result)
69
-
70
- history += pack_history_conversations(
71
- KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
72
- )
73
- final_result += glean_result
74
- if loop_index == max_loop - 1:
75
- break
76
-
77
- records = split_string_by_multi_markers(
78
- final_result,
79
- [
80
- KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
81
- KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
82
- ],
83
- )
84
-
85
- nodes = defaultdict(list)
86
- edges = defaultdict(list)
87
-
88
- for record in records:
89
- record = re.search(r"\((.*)\)", record)
90
- if record is None:
91
- continue
92
- record = record.group(1) # 提取括号内的内容
93
- record_attributes = split_string_by_multi_markers(
94
- record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
95
- )
96
-
97
- entity = await handle_single_entity_extraction(record_attributes, chunk_id)
98
- if entity is not None:
99
- nodes[entity["entity_name"]].append(entity)
100
- continue
101
- relation = await handle_single_relationship_extraction(
102
- record_attributes, chunk_id
103
- )
104
- if relation is not None:
105
- edges[(relation["src_id"], relation["tgt_id"])].append(relation)
106
- return dict(nodes), dict(edges)
107
-
108
- results = await run_concurrent(
109
- _process_single_content,
110
- chunks,
111
- desc="[2/4]Extracting entities and relationships from chunks",
112
- unit="chunk",
113
- progress_bar=progress_bar,
114
- )
115
-
116
- nodes = defaultdict(list)
117
- edges = defaultdict(list)
118
- for n, e in results:
119
- for k, v in n.items():
120
- nodes[k].extend(v)
121
- for k, v in e.items():
122
- edges[tuple(sorted(k))].extend(v)
123
-
124
- await merge_nodes(nodes, kg_instance, llm_client, tokenizer_instance)
125
- await merge_edges(edges, kg_instance, llm_client, tokenizer_instance)
126
-
127
- return kg_instance
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
graphgen/operators/build_kg/merge_kg.py DELETED
@@ -1,212 +0,0 @@
1
- import asyncio
2
- from collections import Counter
3
-
4
- from tqdm.asyncio import tqdm as tqdm_async
5
-
6
- from graphgen.bases import BaseGraphStorage, BaseLLMClient
7
- from graphgen.models import Tokenizer
8
- from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
9
- from graphgen.utils import detect_main_language, logger
10
- from graphgen.utils.format import split_string_by_multi_markers
11
-
12
-
13
- async def _handle_kg_summary(
14
- entity_or_relation_name: str,
15
- description: str,
16
- llm_client: BaseLLMClient,
17
- tokenizer_instance: Tokenizer,
18
- max_summary_tokens: int = 200,
19
- ) -> str:
20
- """
21
- 处理实体或关系的描述信息
22
-
23
- :param entity_or_relation_name
24
- :param description
25
- :param llm_client
26
- :param tokenizer_instance
27
- :param max_summary_tokens
28
- :return: new description
29
- """
30
- language = detect_main_language(description)
31
- if language == "en":
32
- language = "English"
33
- else:
34
- language = "Chinese"
35
- KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
36
-
37
- tokens = tokenizer_instance.encode(description)
38
- if len(tokens) < max_summary_tokens:
39
- return description
40
-
41
- use_description = tokenizer_instance.decode(tokens[:max_summary_tokens])
42
- prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
43
- entity_name=entity_or_relation_name,
44
- description_list=use_description.split("<SEP>"),
45
- **KG_SUMMARIZATION_PROMPT["FORMAT"],
46
- )
47
- new_description = await llm_client.generate_answer(prompt)
48
- logger.info(
49
- "Entity or relation %s summary: %s", entity_or_relation_name, new_description
50
- )
51
- return new_description
52
-
53
-
54
- async def merge_nodes(
55
- nodes_data: dict,
56
- kg_instance: BaseGraphStorage,
57
- llm_client: BaseLLMClient,
58
- tokenizer_instance: Tokenizer,
59
- max_concurrent: int = 1000,
60
- ):
61
- """
62
- Merge nodes
63
-
64
- :param nodes_data
65
- :param kg_instance
66
- :param llm_client
67
- :param tokenizer_instance
68
- :param max_concurrent
69
- :return
70
- """
71
-
72
- semaphore = asyncio.Semaphore(max_concurrent)
73
-
74
- async def process_single_node(entity_name: str, node_data: list[dict]):
75
- async with semaphore:
76
- entity_types = []
77
- source_ids = []
78
- descriptions = []
79
-
80
- node = await kg_instance.get_node(entity_name)
81
- if node is not None:
82
- entity_types.append(node["entity_type"])
83
- source_ids.extend(
84
- split_string_by_multi_markers(node["source_id"], ["<SEP>"])
85
- )
86
- descriptions.append(node["description"])
87
-
88
- # 统计当前节点数据和已有节点数据的entity_type出现次数,取出现次数最多的entity_type
89
- entity_type = sorted(
90
- Counter([dp["entity_type"] for dp in node_data] + entity_types).items(),
91
- key=lambda x: x[1],
92
- reverse=True,
93
- )[0][0]
94
-
95
- description = "<SEP>".join(
96
- sorted(set([dp["description"] for dp in node_data] + descriptions))
97
- )
98
- description = await _handle_kg_summary(
99
- entity_name, description, llm_client, tokenizer_instance
100
- )
101
-
102
- source_id = "<SEP>".join(
103
- set([dp["source_id"] for dp in node_data] + source_ids)
104
- )
105
-
106
- node_data = {
107
- "entity_type": entity_type,
108
- "description": description,
109
- "source_id": source_id,
110
- }
111
- await kg_instance.upsert_node(entity_name, node_data=node_data)
112
- node_data["entity_name"] = entity_name
113
- return node_data
114
-
115
- logger.info("Inserting entities into storage...")
116
- entities_data = []
117
- for result in tqdm_async(
118
- asyncio.as_completed(
119
- [process_single_node(k, v) for k, v in nodes_data.items()]
120
- ),
121
- total=len(nodes_data),
122
- desc="Inserting entities into storage",
123
- unit="entity",
124
- ):
125
- try:
126
- entities_data.append(await result)
127
- except Exception as e: # pylint: disable=broad-except
128
- logger.error("Error occurred while inserting entities into storage: %s", e)
129
-
130
-
131
- async def merge_edges(
132
- edges_data: dict,
133
- kg_instance: BaseGraphStorage,
134
- llm_client: BaseLLMClient,
135
- tokenizer_instance: Tokenizer,
136
- max_concurrent: int = 1000,
137
- ):
138
- """
139
- Merge edges
140
-
141
- :param edges_data
142
- :param kg_instance
143
- :param llm_client
144
- :param tokenizer_instance
145
- :param max_concurrent
146
- :return
147
- """
148
-
149
- semaphore = asyncio.Semaphore(max_concurrent)
150
-
151
- async def process_single_edge(src_id: str, tgt_id: str, edge_data: list[dict]):
152
- async with semaphore:
153
- source_ids = []
154
- descriptions = []
155
-
156
- edge = await kg_instance.get_edge(src_id, tgt_id)
157
- if edge is not None:
158
- source_ids.extend(
159
- split_string_by_multi_markers(edge["source_id"], ["<SEP>"])
160
- )
161
- descriptions.append(edge["description"])
162
-
163
- description = "<SEP>".join(
164
- sorted(set([dp["description"] for dp in edge_data] + descriptions))
165
- )
166
- source_id = "<SEP>".join(
167
- set([dp["source_id"] for dp in edge_data] + source_ids)
168
- )
169
-
170
- for insert_id in [src_id, tgt_id]:
171
- if not await kg_instance.has_node(insert_id):
172
- await kg_instance.upsert_node(
173
- insert_id,
174
- node_data={
175
- "source_id": source_id,
176
- "description": description,
177
- "entity_type": "UNKNOWN",
178
- },
179
- )
180
-
181
- description = await _handle_kg_summary(
182
- f"({src_id}, {tgt_id})", description, llm_client, tokenizer_instance
183
- )
184
-
185
- await kg_instance.upsert_edge(
186
- src_id,
187
- tgt_id,
188
- edge_data={"source_id": source_id, "description": description},
189
- )
190
-
191
- edge_data = {"src_id": src_id, "tgt_id": tgt_id, "description": description}
192
- return edge_data
193
-
194
- logger.info("Inserting relationships into storage...")
195
- relationships_data = []
196
- for result in tqdm_async(
197
- asyncio.as_completed(
198
- [
199
- process_single_edge(src_id, tgt_id, v)
200
- for (src_id, tgt_id), v in edges_data.items()
201
- ]
202
- ),
203
- total=len(edges_data),
204
- desc="Inserting relationships into storage",
205
- unit="relationship",
206
- ):
207
- try:
208
- relationships_data.append(await result)
209
- except Exception as e: # pylint: disable=broad-except
210
- logger.error(
211
- "Error occurred while inserting relationships into storage: %s", e
212
- )