| | |
| |
|
import importlib
import os
import re
import time
import traceback
from pathlib import Path
from typing import List

from ...core.logging import logger
from ...models.base_model import BaseLLM
from ...prompts.optimizers.aflow_optimizer import (
    WORKFLOW_CUSTOM_USE,
    WORKFLOW_INPUT,
    WORKFLOW_OPTIMIZE_PROMPT,
    WORKFLOW_TEMPLATE,
)
from ...workflow.operators import (
    AnswerGenerate,
    Custom,
    CustomCodeGenerate,
    Operator,
    Programmer,
    QAScEnsemble,
    ScEnsemble,
    Test,
)
| |
|
# Registry of workflow operator classes, keyed by the name used in
# workflow specifications and optimizer prompts.
OPERATOR_MAP = dict(
    Custom=Custom,
    CustomCodeGenerate=CustomCodeGenerate,
    ScEnsemble=ScEnsemble,
    Test=Test,
    AnswerGenerate=AnswerGenerate,
    QAScEnsemble=QAScEnsemble,
    Programmer=Programmer,
)
| |
|
| |
|
| | class GraphUtils: |
| |
|
| | def __init__(self, root_path: str): |
| | self.root_path = root_path |
| |
|
| | def create_round_directory(self, graph_path: str, round_number: int) -> str: |
| | directory = os.path.join(graph_path, f"round_{round_number}") |
| | os.makedirs(directory, exist_ok=True) |
| | return directory |
| |
|
| | def load_graph(self, round_number: int, workflows_path: str): |
| | workflows_path = workflows_path.replace("\\", ".").replace("/", ".") |
| | graph_module_name = f"{workflows_path}.round_{round_number}.graph" |
| | try: |
| | graph_module = __import__(graph_module_name, fromlist=[""]) |
| | graph_class = getattr(graph_module, "Workflow") |
| | return graph_class |
| | except ImportError as e: |
| | logger.info(f"Error loading graph for round {round_number}: {e}") |
| | raise |
| |
|
| | def read_graph_files(self, round_number: int, workflows_path: str): |
| | prompt_file_path = os.path.join(workflows_path, f"round_{round_number}", "prompt.py") |
| | graph_file_path = os.path.join(workflows_path, f"round_{round_number}", "graph.py") |
| |
|
| | try: |
| | with open(prompt_file_path, "r", encoding="utf-8") as file: |
| | prompt_content = file.read() |
| | with open(graph_file_path, "r", encoding="utf-8") as file: |
| | graph_content = file.read() |
| | except FileNotFoundError as e: |
| | logger.info(f"Error: File not found for round {round_number}: {e}") |
| | raise |
| | except Exception as e: |
| | logger.info(f"Error loading prompt for round {round_number}: {e}") |
| | raise |
| | return prompt_content, graph_content |
| |
|
| | def extract_solve_graph(self, graph_load: str) -> List[str]: |
| | pattern = r"class Workflow:.+" |
| | return re.findall(pattern, graph_load, re.DOTALL) |
| |
|
| | def load_operators_description(self, operators: List[str], llm: BaseLLM) -> str: |
| |
|
| | operators_description = "" |
| | for id, operator in enumerate(operators): |
| | operator_description = self._load_operator_description(id + 1, operator, llm) |
| | operators_description += f"{operator_description}\n" |
| | return operators_description |
| |
|
| | def _load_operator_description(self, id: int, operator_name: str, llm: BaseLLM) -> str: |
| | if operator_name not in OPERATOR_MAP: |
| | raise ValueError(f"Operator {operator_name} not Found in OPERATOR_MAP! Available operators: {OPERATOR_MAP.keys()}") |
| | operator: Operator = OPERATOR_MAP[operator_name](llm=llm) |
| | return f"{id}. {operator_name}: {operator.description}, with interface {operator.interface})." |
| |
|
| | def create_graph_optimize_prompt( |
| | self, |
| | experience: str, |
| | score: float, |
| | graph: str, |
| | prompt: str, |
| | operator_description: str, |
| | type: str, |
| | log_data: str, |
| | ) -> str: |
| | graph_input = WORKFLOW_INPUT.format( |
| | experience=experience, |
| | score=score, |
| | graph=graph, |
| | prompt=prompt, |
| | operator_description=operator_description, |
| | type=type, |
| | log=log_data, |
| | ) |
| | graph_system = WORKFLOW_OPTIMIZE_PROMPT.format(type=type) |
| | return graph_input + WORKFLOW_CUSTOM_USE + graph_system |
| |
|
| | def get_graph_optimize_response(self, graph_optimize_node): |
| | max_retries = 5 |
| | retries = 0 |
| |
|
| | while retries < max_retries: |
| | try: |
| | response = graph_optimize_node.instruct_content.model_dump() |
| | return response |
| | except Exception as e: |
| | retries += 1 |
| | logger.info(f"Error generating prediction: {e}. Retrying... ({retries}/{max_retries})") |
| | if retries == max_retries: |
| | logger.info("Maximum retries reached. Skipping this sample.") |
| | break |
| | traceback.print_exc() |
| | time.sleep(5) |
| | return None |
| |
|
| | def write_graph_files(self, directory: str, response: dict): |
| | |
| | graph = WORKFLOW_TEMPLATE.format(graph=response["graph"]) |
| | with open(os.path.join(directory, "graph.py"), "w", encoding="utf-8") as file: |
| | file.write(graph) |
| | with open(os.path.join(directory, "prompt.py"), "w", encoding="utf-8") as file: |
| | prompt = response["prompt"].replace("prompt_custom.", "") |
| | file.write(prompt) |
| | with open(os.path.join(directory, "__init__.py"), "w", encoding="utf-8") as file: |
| | file.write("") |
| | self.update_prompt_import(os.path.join(directory, "graph.py"), directory) |
| | |
| | def update_prompt_import(self, graph_file: str, prompt_folder: str): |
| |
|
| | project_root = Path(os.getcwd()) |
| | prompt_folder_path = Path(prompt_folder) |
| |
|
| | if not prompt_folder_path.is_absolute(): |
| | prompt_folder_full_path = Path(os.path.join(project_root, prompt_folder)) |
| | if not prompt_folder_full_path.exists(): |
| | raise ValueError(f"Prompt folder {prompt_folder_full_path} does not exist!") |
| | prompt_folder_path = prompt_folder_full_path |
| | |
| | try: |
| | relative_path = prompt_folder_path.relative_to(project_root) |
| | except ValueError: |
| | raise ValueError(f"Prompt folder {prompt_folder} must be within the project directory") |
| |
|
| | import_path = str(relative_path).replace(os.sep, ".") |
| | if import_path.startswith("."): |
| | import_path = import_path[1:] |
| | |
| | with open(graph_file, "r", encoding="utf-8") as file: |
| | graph_content = file.read() |
| |
|
| | |
| | pattern = r'import .*?\.prompt as prompt_custom' |
| | replacement = f'import {import_path}.prompt as prompt_custom' |
| | new_content = re.sub(pattern, replacement, graph_content) |
| |
|
| | with open(graph_file, "w", encoding="utf-8") as file: |
| | file.write(new_content) |
| | |
| |
|
| | |