Spaces:
Running
Running
| import argparse | |
| import os | |
| import time | |
| from importlib.resources import files | |
| import yaml | |
| from dotenv import load_dotenv | |
| from .graphgen import GraphGen | |
| from .utils import logger, set_logger | |
# Absolute path of the directory containing this module (note: despite the
# name, this is unrelated to ``sys.path``).
sys_path = os.path.abspath(os.path.dirname(__file__))
# Load variables from a ``.env`` file into ``os.environ`` at import time.
# NOTE(review): presumably so credentials/endpoints are visible to GraphGen
# before main() runs — confirm which variables GraphGen actually reads.
load_dotenv()
def set_working_dir(folder):
    """Ensure the GraphGen working-directory layout exists under *folder*.

    Creates *folder* itself, ``<folder>/data/graphgen`` and ``<folder>/logs``.
    Directories that already exist are left untouched, so calling this
    repeatedly is harmless.
    """
    # Relative sub-paths to create, in order; the empty tuple is *folder* itself.
    layout = ((), ("data", "graphgen"), ("logs",))
    for sub_parts in layout:
        os.makedirs(os.path.join(folder, *sub_parts), exist_ok=True)
def save_config(config_path, global_config):
    """Write *global_config* to *config_path* as block-style YAML.

    Any missing parent directories of *config_path* are created first.
    Non-ASCII characters are written as-is (``allow_unicode=True``) and the
    file is encoded as UTF-8.

    Args:
        config_path: Destination file path; may be a bare filename, in which
            case it is written relative to the current working directory.
        global_config: YAML-serializable object (the pipeline config).
    """
    parent_dir = os.path.dirname(config_path)
    # A bare filename has an empty parent ("") and os.makedirs("") raises, so
    # only create a directory when there is one. exist_ok=True replaces the
    # former exists-then-create check, which could race with another process.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(config_path, "w", encoding="utf-8") as config_file:
        yaml.dump(
            global_config, config_file, default_flow_style=False, allow_unicode=True
        )
def _parse_args():
    """Build the GraphGen command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        help="Config parameters for GraphGen.",
        # argparse only runs ``type`` over *string* defaults, so convert the
        # packaged resource path to str to keep args.config_file a str either way.
        default=str(files("graphgen").joinpath("configs", "aggregated_config.yaml")),
        type=str,
    )
    parser.add_argument(
        "--output_dir",
        help="Output directory for GraphGen.",
        # The option is required, so argparse never uses a default; the former
        # ``default=sys_path`` was dead and has been dropped.
        required=True,
        type=str,
    )
    return parser.parse_args()


def _section_enabled(config, section):
    """Return whether ``config[section]["enabled"]`` is truthy.

    A missing section, or one left empty in YAML (parsed as ``None``), counts
    as disabled instead of raising.
    """
    return bool((config.get(section) or {}).get("enabled", False))


def main():
    """CLI entry point: run the GraphGen pipeline described by a YAML config.

    Steps: parse ``--config_file`` / ``--output_dir``, create the working
    directory layout, load the config, configure logging to
    ``<output_dir>/logs/graphgen_<type>_<id>.log``, run ``insert()``, run
    ``search()`` when ``search.enabled`` is set, then the generation stage
    selected by ``config["output_data_type"]``. Finally the config is saved
    to ``<output_dir>/data/graphgen/<id>/config-<id>.yaml``.

    Raises:
        ValueError: ``output_data_type`` is not one of ``atomic``,
            ``aggregated``, ``multi_hop`` or ``cot``. This is checked before
            any GraphGen work starts.
        KeyError: the config lacks ``output_data_type`` (or ``method_params``
            when ``output_data_type`` is ``cot``).
    """
    args = _parse_args()
    working_dir = args.output_dir
    set_working_dir(working_dir)

    # NOTE(review): FullLoader on a user-supplied local file; switch to
    # yaml.safe_load if configs may ever come from untrusted sources.
    with open(args.config_file, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    output_data_type = config["output_data_type"]
    traversal_types = ("atomic", "aggregated", "multi_hop")
    # Fail fast: previously an unsupported type was only rejected after the
    # (expensive) insert and search stages had already run.
    if output_data_type not in traversal_types and output_data_type != "cot":
        raise ValueError(f"Unsupported output data type: {output_data_type}")

    # Second-resolution timestamp used to tag this run's log and output files.
    unique_id = int(time.time())
    log_file = os.path.join(
        working_dir, "logs", f"graphgen_{output_data_type}_{unique_id}.log"
    )
    set_logger(log_file, if_stream=True)
    logger.info("GraphGen with unique ID %s logging to %s", unique_id, log_file)

    graph_gen = GraphGen(working_dir=working_dir, unique_id=unique_id, config=config)
    graph_gen.insert()
    if _section_enabled(config, "search"):
        graph_gen.search()

    # Use pipeline according to the output data type
    if output_data_type in traversal_types:
        if _section_enabled(config, "quiz_and_judge_strategy"):
            graph_gen.quiz()
            graph_gen.judge()
        else:
            logger.warning(
                "Quiz and Judge strategy is disabled. Edge sampling falls back to random."
            )
            graph_gen.traverse_strategy.edge_sampling = "random"
        graph_gen.traverse()
    else:  # "cot" — the only other value that passed validation above
        graph_gen.generate_reasoning(method_params=config["method_params"])

    output_path = os.path.join(working_dir, "data", "graphgen", str(unique_id))
    save_config(os.path.join(output_path, f"config-{unique_id}.yaml"), config)
    logger.info("GraphGen completed successfully. Data saved to %s", output_path)
# Script entry point. The relative imports at the top of this file mean it
# must be executed as a module (``python -m <package>.<module>``) rather than
# as a plain script path.
if __name__ == "__main__":
    main()