Spaces:
Running
Running
File size: 3,202 Bytes
fb9c306 acd7cf4 fb9c306 acd7cf4 817f16e acd7cf4 fb9c306 acd7cf4 fb9c306 acd7cf4 fb9c306 acd7cf4 fb9c306 acd7cf4 fb9c306 acd7cf4 817f16e fb9c306 bda6eda 817f16e bda6eda fb9c306 817f16e fb9c306 acd7cf4 fb9c306 817f16e acd7cf4 817f16e fb9c306 817f16e fb9c306 817f16e fb9c306 817f16e fb9c306 817f16e 1f2cf0b 817f16e fb9c306 817f16e acd7cf4 bda6eda fb9c306 acd7cf4 fb9c306 acd7cf4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import argparse
import os
import time
from importlib.resources import files
import yaml
from dotenv import load_dotenv
from graphgen.graphgen import GraphGen
from graphgen.utils import logger, set_logger
sys_path = os.path.abspath(os.path.dirname(__file__))
load_dotenv()
def set_working_dir(folder):
os.makedirs(folder, exist_ok=True)
def save_config(config_path, global_config):
if not os.path.exists(os.path.dirname(config_path)):
os.makedirs(os.path.dirname(config_path))
with open(config_path, "w", encoding="utf-8") as config_file:
yaml.dump(
global_config, config_file, default_flow_style=False, allow_unicode=True
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config_file",
help="Config parameters for GraphGen.",
default=files("graphgen").joinpath("configs", "aggregated_config.yaml"),
type=str,
)
parser.add_argument(
"--output_dir",
help="Output directory for GraphGen.",
default=sys_path,
required=True,
type=str,
)
args = parser.parse_args()
working_dir = args.output_dir
with open(args.config_file, "r", encoding="utf-8") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
mode = config["generate"]["mode"]
unique_id = int(time.time())
output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}")
set_working_dir(output_path)
set_logger(
os.path.join(output_path, f"{unique_id}_{mode}.log"),
if_stream=True,
)
logger.info(
"GraphGen with unique ID %s logging to %s",
unique_id,
os.path.join(working_dir, f"{unique_id}_{mode}.log"),
)
graph_gen = GraphGen(unique_id=unique_id, working_dir=working_dir)
graph_gen.insert(read_config=config["read"], split_config=config["split"])
graph_gen.search(search_config=config["search"])
# Use pipeline according to the output data type
if mode in ["atomic", "aggregated", "multi_hop"]:
logger.info("Generation mode set to '%s'. Start generation.", mode)
if "quiz_and_judge" in config and config["quiz_and_judge"]["enabled"]:
graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
else:
logger.warning(
"Quiz and Judge strategy is disabled. Edge sampling falls back to random."
)
assert (
config["partition"]["method"] == "ece"
and "method_params" in config["partition"]
), "Only ECE partition with edge sampling is supported."
config["partition"]["method_params"]["edge_sampling"] = "random"
elif mode == "cot":
logger.info("Generation mode set to 'cot'. Start generation.")
else:
raise ValueError(f"Unsupported output data type: {mode}")
graph_gen.generate(
partition_config=config["partition"],
generate_config=config["generate"],
)
save_config(os.path.join(output_path, "config.yaml"), config)
logger.info("GraphGen completed successfully. Data saved to %s", output_path)
if __name__ == "__main__":
main()
|