File size: 3,202 Bytes
fb9c306
acd7cf4
 
 
fb9c306
acd7cf4
 
 
817f16e
 
acd7cf4
 
 
 
 
fb9c306
acd7cf4
 
 
fb9c306
acd7cf4
 
 
fb9c306
 
 
 
 
acd7cf4
 
 
fb9c306
 
 
 
 
 
 
 
 
 
 
 
 
acd7cf4
 
 
 
 
fb9c306
acd7cf4
 
817f16e
fb9c306
bda6eda
817f16e
bda6eda
 
fb9c306
817f16e
fb9c306
acd7cf4
fb9c306
 
 
817f16e
acd7cf4
 
817f16e
fb9c306
817f16e
fb9c306
817f16e
fb9c306
 
817f16e
 
 
 
fb9c306
 
 
 
817f16e
 
1f2cf0b
817f16e
 
 
 
fb9c306
817f16e
 
 
 
 
 
acd7cf4
bda6eda
fb9c306
acd7cf4
 
fb9c306
acd7cf4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import argparse
import os
import time
from importlib.resources import files

import yaml
from dotenv import load_dotenv

from graphgen.graphgen import GraphGen
from graphgen.utils import logger, set_logger

# Absolute path of the directory that contains this script.
sys_path = os.path.abspath(os.path.dirname(__file__))

# python-dotenv: load key/value pairs from a ``.env`` file into the process
# environment at import time. Presumably this supplies API keys / endpoints
# that GraphGen reads from the environment — TODO confirm.
load_dotenv()


def set_working_dir(folder):
    """Make sure *folder* exists, creating any missing parent directories.

    Calling this on a directory that already exists is a silent no-op.
    """
    os.makedirs(name=folder, mode=0o777, exist_ok=True)


def save_config(config_path, global_config):
    """Write *global_config* to *config_path* as block-style YAML.

    Non-ASCII characters are written as-is (``allow_unicode=True``) and the
    file is UTF-8 encoded. Missing parent directories are created first; a
    bare filename is written relative to the current working directory.

    Args:
        config_path: Destination file path.
        global_config: YAML-serialisable object to dump.
    """
    parent_dir = os.path.dirname(config_path)
    # ``dirname`` is "" for a bare filename, and ``os.makedirs("")`` raises
    # FileNotFoundError, so only create a directory when there is one.
    # ``exist_ok=True`` replaces the old exists-then-create check, which could
    # race with another process creating the same directory.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(config_path, "w", encoding="utf-8") as config_file:
        yaml.dump(
            global_config, config_file, default_flow_style=False, allow_unicode=True
        )


def main():
    """Command-line entry point: run the full GraphGen pipeline.

    Parses ``--config_file`` (YAML; defaults to the packaged
    ``configs/aggregated_config.yaml``) and the required ``--output_dir``,
    then:

    1. validates ``generate.mode`` and, when quiz-and-judge will not run, the
       partition config — before any expensive work is started;
    2. creates ``<output_dir>/data/graphgen/<unique_id>`` and logs there;
    3. runs ``insert`` -> ``search`` -> optional ``quiz_and_judge`` ->
       ``generate`` on a ``GraphGen`` whose working dir is ``output_dir``;
    4. saves the (possibly modified) config as ``config.yaml`` in the run dir.

    A ``quiz_and_judge`` section without an ``enabled`` key is treated as
    disabled.

    Raises:
        ValueError: if ``generate.mode`` is not one of ``atomic``,
            ``aggregated``, ``multi_hop`` or ``cot``; or if quiz-and-judge is
            disabled for a non-``cot`` mode and ``partition`` is not an ECE
            partition that has ``method_params``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        help="Config parameters for GraphGen.",
        default=files("graphgen").joinpath("configs", "aggregated_config.yaml"),
        type=str,
    )
    # ``required=True`` means argparse never falls back to a default, so the
    # former ``default=sys_path`` was dead and misleading; it was dropped.
    parser.add_argument(
        "--output_dir",
        help="Output directory for GraphGen.",
        required=True,
        type=str,
    )

    args = parser.parse_args()

    working_dir = args.output_dir

    # NOTE(review): FullLoader can construct more than plain YAML data;
    # yaml.safe_load is the safer choice if configs never rely on
    # Python-specific tags. Kept as-is to avoid rejecting existing configs.
    with open(args.config_file, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # Validate everything checkable from the config alone up front, so a bad
    # config fails before the (potentially slow) insert/search stages and
    # before an empty run directory is created.
    mode = config["generate"]["mode"]
    if mode not in ("atomic", "aggregated", "multi_hop", "cot"):
        raise ValueError(f"Unsupported output data type: {mode}")

    quiz_and_judge_config = config.get("quiz_and_judge") or {}
    run_quiz_and_judge = mode != "cot" and bool(quiz_and_judge_config.get("enabled"))
    if mode != "cot" and not run_quiz_and_judge:
        partition_config = config["partition"]
        # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
        if not (
            partition_config.get("method") == "ece"
            and "method_params" in partition_config
        ):
            raise ValueError("Only ECE partition with edge sampling is supported.")

    # Seconds-resolution timestamp: two runs started within the same second
    # would share a run directory (set_working_dir tolerates an existing one).
    unique_id = int(time.time())

    output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}")
    set_working_dir(output_path)

    # Compute the log path once so the message reports where logs really go
    # (it previously pointed at ``working_dir`` instead of ``output_path``).
    log_file = os.path.join(output_path, f"{unique_id}_{mode}.log")
    set_logger(log_file, if_stream=True)
    logger.info("GraphGen with unique ID %s logging to %s", unique_id, log_file)

    graph_gen = GraphGen(unique_id=unique_id, working_dir=working_dir)

    graph_gen.insert(read_config=config["read"], split_config=config["split"])

    graph_gen.search(search_config=config["search"])

    # Use pipeline according to the output data type
    if mode == "cot":
        logger.info("Generation mode set to 'cot'. Start generation.")
    else:
        logger.info("Generation mode set to '%s'. Start generation.", mode)
        if run_quiz_and_judge:
            graph_gen.quiz_and_judge(quiz_and_judge_config=quiz_and_judge_config)
        else:
            logger.warning(
                "Quiz and Judge strategy is disabled. Edge sampling falls back to random."
            )
            # Mutates ``config`` in place: the fallback is both passed to
            # ``generate`` and recorded in the saved config.yaml below.
            config["partition"]["method_params"]["edge_sampling"] = "random"

    graph_gen.generate(
        partition_config=config["partition"],
        generate_config=config["generate"],
    )

    save_config(os.path.join(output_path, "config.yaml"), config)
    logger.info("GraphGen completed successfully. Data saved to %s", output_path)


if __name__ == "__main__":
    main()