| import json |
| from copy import deepcopy |
| from typing import Any, Dict, List |
|
|
| from flow_modules.aiflows.ChatFlowModule import ChatAtomicFlow |
|
|
|
|
| from dataclasses import dataclass |
|
|
|
|
@dataclass
class Command:
    """Description of one command the controller can offer to the model.

    Instances are rendered into the numbered "commands manual" of the
    controller's system prompt. Fields are positional in the order
    ``name, description, input_args``.
    """

    # Label shown at the start of the command's manual entry.
    name: str
    # Free-text explanation printed after the name in the manual entry.
    description: str
    # Argument names; each becomes a key (with a YOUR_<ARG> placeholder value)
    # in the example JSON schema shown for this command.
    input_args: List[str]
|
|
class Controller_CoderFlow(ChatAtomicFlow):
    """Controller of the Coder flow: asks the LLM which command to issue next.

    Refer to: https://huggingface.co/Tachi67/JarvisFlowModule/blob/main/Controller_JarvisFlow.py

    The system prompt template is pre-filled with a numbered commands manual
    and is refreshed on every ``run`` with the current plan, code library,
    logs and the file locations of the plan and code library. ``run`` returns
    the JSON object parsed from the model's reply.
    """

    def __init__(
            self,
            commands: List[Command],
            **kwargs):
        """Create the flow and pre-fill the system prompt.

        Args:
            commands: Commands rendered into the system prompt via
                ``_build_commands_manual``.
            **kwargs: Forwarded unchanged to ``ChatAtomicFlow.__init__``.
        """
        super().__init__(**kwargs)
        # Placeholder values until the first ``run`` supplies real content.
        self.system_message_prompt_template = self.system_message_prompt_template.partial(
            commands=self._build_commands_manual(commands),
            plan="no plans yet",
            plan_file_location="no location yet",
            code_library="no code library yet",
            code_library_location="no location yet",
            logs="no logs yet",
        )
        # Appended to the "goal"/"result" inputs to remind the model of the
        # expected JSON reply shape.
        self.hint_for_model = (
            '\n'
            'Make sure your response is in the following format:\n'
            'Response Format:\n'
            '{\n'
            '"command": "call one of the subordinates",\n'
            '"command_args": {\n'
            '"arg name": "value"\n'
            '}\n'
            '}\n'
        )

    @staticmethod
    def _build_commands_manual(commands: List[Command]) -> str:
        """Render ``commands`` as a numbered, newline-terminated manual.

        Each entry reads ``"<n>. <name>: <description> Input arguments (given
        in the JSON schema): <schema>"`` where ``<schema>`` is a JSON object
        mapping every input arg to a ``YOUR_<ARG>`` placeholder. Numbering
        starts at 1; an empty ``commands`` list yields ``""``.
        """
        entries = []
        for i, command in enumerate(commands, start=1):
            command_input_json_schema = json.dumps(
                {input_arg: f"YOUR_{input_arg.upper()}" for input_arg in command.input_args})
            entries.append(
                f"{i}. {command.name}: {command.description} "
                f"Input arguments (given in the JSON schema): {command_input_json_schema}\n")
        return "".join(entries)

    def _get_content_file_location(self, input_data, content_name):
        """Return ``input_data["memory_files"][content_name]``.

        Raises:
            AssertionError: if ``memory_files`` or the requested entry is
                missing (a ``KeyError`` surfaces instead under ``python -O``).
        """
        assert "memory_files" in input_data, "memory_files not passed to Coder/Controller"
        assert content_name in input_data["memory_files"], f"{content_name} not in memory files"
        return input_data["memory_files"][content_name]

    def _get_content(self, input_data, content_name):
        """Return ``input_data[content_name]``, or a placeholder if it is empty.

        Empty (or ``None``) content is replaced by ``"No <content_name> yet"``
        so the prompt never shows a blank section.

        Raises:
            AssertionError: if ``content_name`` is not in ``input_data``.
        """
        assert content_name in input_data, f"{content_name} not passed to Coder/Controller"
        content = input_data[content_name]
        # ``None`` previously crashed in ``len``; treat it like empty content.
        if content is None or len(content) == 0:
            content = f'No {content_name} yet'
        return content

    @classmethod
    def instantiate_from_config(cls, config):
        """Build the flow from a config mapping.

        ``config`` is deep-copied. Prompts and backend kwargs come from the
        ``_set_up_prompts`` / ``_set_up_backend`` class helpers (not defined in
        this class). ``config["commands"]`` maps each command name to a dict
        with ``"description"`` and ``"input_args"``, and becomes a list of
        ``Command`` objects.
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}
        kwargs.update(cls._set_up_prompts(flow_config))
        kwargs.update(cls._set_up_backend(flow_config))

        kwargs["commands"] = [
            Command(name, command_conf["description"], command_conf["input_args"])
            for name, command_conf in flow_config["commands"].items()
        ]

        return cls(**kwargs)

    def _update_prompts_and_input(self, input_data: Dict[str, Any]) -> None:
        """Append the JSON hint to the input and refresh the system prompt.

        Mutates ``input_data`` in place: ``hint_for_model`` is appended to
        ``goal`` and/or ``result`` when present. Then calls ``.partial(...)``
        on the system prompt template with the current plan, code library,
        logs, and the plan / code library file locations.
        """
        if 'goal' in input_data:
            input_data['goal'] += self.hint_for_model
        if 'result' in input_data:
            input_data['result'] += self.hint_for_model
        # Keyword arguments are evaluated left to right, so the lookups (and
        # any assertion failures) happen in the same order as before.
        self.system_message_prompt_template = self.system_message_prompt_template.partial(
            plan_file_location=self._get_content_file_location(input_data, "plan"),
            plan=self._get_content(input_data, "plan"),
            code_library_location=self._get_content_file_location(input_data, "code_library"),
            code_library=self._get_content(input_data, "code_library"),
            logs=self._get_content(input_data, "logs"),
        )

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Query the model until its reply contains a parsable JSON object.

        The caller's ``input_data`` is copied first, so appending the format
        hint does not modify it. On a parse failure the system message is
        added again and the model is re-asked with a "purely json parsable"
        reminder as ``result``. Retries are unbounded unless
        ``flow_config["max_json_parse_retries"]`` is set.

        Returns:
            The dict parsed from the outermost ``{...}`` span of the reply.

        Raises:
            ValueError: if ``max_json_parse_retries`` is set and exhausted.
        """
        # Private copy: the hint is appended in place and must not leak into
        # the caller's dict.
        input_data = input_data.copy()
        self._update_prompts_and_input(input_data)

        if self._is_conversation_initialized():
            # Conversation already running: push the refreshed system prompt
            # (new plan / code library / logs) as an extra system message.
            updated_system_message_content = self._get_message(self.system_message_prompt_template, input_data)
            self._state_update_add_chat_message(content=updated_system_message_content,
                                                role=self.flow_config["system_name"])

        # None (the default) keeps the original unbounded retry behaviour.
        max_retries = self.flow_config.get("max_json_parse_retries")
        failed_attempts = 0
        while True:
            api_output = super().run(input_data)["api_output"].strip()
            try:
                # Keep only the outermost {...} span, dropping surrounding prose.
                start = api_output.index("{")
                end = api_output.rindex("}") + 1
                return json.loads(api_output[start:end])
            except ValueError as err:  # also covers json.JSONDecodeError
                failed_attempts += 1
                if max_retries is not None and failed_attempts > max_retries:
                    raise ValueError(
                        f"Could not parse a JSON object from the model output after "
                        f"{failed_attempts} attempts: {api_output!r}") from err

            updated_system_message_content = self._get_message(self.system_message_prompt_template, input_data)
            self._state_update_add_chat_message(content=updated_system_message_content,
                                                role=self.flow_config["system_name"])
            new_goal = "The previous respond cannot be parsed with json.loads. Next time, do not provide any comments or code blocks. Make sure your next response is purely json parsable."
            # Fresh dict per attempt so earlier inputs handed to super().run
            # are never mutated afterwards.
            input_data = input_data.copy()
            input_data['result'] = new_goal
|
|
|
|