| | from typing import Union, Dict |
| | from werkzeug.datastructures import FileStorage |
| | from ...tools.base_tool import BaseTool |
| | from ...utils import clean_ansi, get_logger |
| | from jupyter_client import BlockingKernelClient |
| | import json |
| | import os |
| | import queue |
| | import re |
| | import subprocess |
| | import sys |
| | import time |
| | import traceback |
| | from enum import Enum |
| | from ...utils.file_utils import clear_files |
| |
|
logger = get_logger()


def _find_root_directory(start: str, marker: str = 'infiagent') -> str:
    """Return the nearest path, starting at ``start`` and walking upwards, whose
    basename contains ``marker``.

    ``start`` itself is checked first, then each parent directory in turn.

    Raises:
        RuntimeError: if the filesystem root is reached without a match. (A bare
            ``while`` loop would spin forever there, since ``dirname('/') == '/'``.)
    """
    current = start
    while marker not in os.path.basename(current):
        parent = os.path.dirname(current)
        if parent == current:
            raise RuntimeError(
                f"Could not find an ancestor directory containing {marker!r} "
                f"in its name, starting from {start!r}"
            )
        current = parent
    return current


# Project root: the closest ancestor of this file whose name contains 'infiagent'.
root_directory = _find_root_directory(os.path.abspath(__file__))

# Per-sandbox kernel scripts / connection files live under WORK_DIR/<sandbox_id>;
# uploaded files live under FILE_DIR/<sandbox_id>.
WORK_DIR = f'{root_directory}/tmp/ci_workspace'
FILE_DIR = f'{root_directory}/tmp/upload_files'
| |
|
| |
|
| | class _Type(Enum): |
| | SUCCESS = 1 |
| | ERROR = 2 |
| | FAIL = 3 |
| |
|
| |
|
class PythonSandBoxToolResponse:
    """Raw text captured from one sandbox execution, paired with its outcome.

    ``raw_output`` returns the text untouched; ``output_text`` renders it as a
    labelled message suitable for display.
    """

    def __init__(self,
                 sand_box_response: str,
                 _type: _Type) -> None:
        self._sand_box_response = sand_box_response
        self._type = _type

    @property
    def output_text(self):
        """The captured text rendered according to its outcome (see ``_format``)."""
        return self._format(self._sand_box_response, self._type)

    @property
    def raw_output(self):
        """The captured text exactly as stored, without any formatting."""
        return self._sand_box_response

    @classmethod
    def _format(cls, sandbox_response, _type):
        """Render ``sandbox_response`` for the given outcome ``_type``.

        - FAIL: an error banner followed by the text as-is.
        - SUCCESS / ERROR: a STDOUT / STDERR header followed by the text with ANSI
          codes removed, fenced as a python code block and newline-terminated.
        - anything else: an empty string.
        """
        if _type == _Type.FAIL:
            return "\nCode execution error\n" + f"What happened: {sandbox_response}"

        labelled_headers = (
            (_Type.SUCCESS, "\nSTDOUT:\n"),
            (_Type.ERROR, "\nSTDERR:\n"),
        )
        for kind, header in labelled_headers:
            if _type == kind:
                body = f"```python\n{clean_ansi(sandbox_response)}\n```" + "\n"
                return header + body
        return ""
| |
|
| |
|
class AsyncPythonSandBoxTool(BaseTool):
    """Tool that runs Python code inside a per-sandbox Jupyter (IPython) kernel.

    A connected kernel client is cached per ``sandbox_id`` in ``_KERNEL_CLIENTS``
    and reused by later calls, so interpreter state persists across executions
    of the same sandbox until ``kill_kernels`` is called.
    """

    # sandbox_id -> connected kernel client. Keys are passed to os.path.join
    # below, so they must be strings.
    _KERNEL_CLIENTS: Dict[str, BlockingKernelClient] = {}
    # Seconds to wait for a freshly launched kernel to write its connection file.
    KERNEL_START_TIMEOUT = 60.0
    # Bootstrap script written to disk for each sandbox. ``!r`` emits a properly
    # quoted/escaped literal, so quotes or backslashes in the path (e.g. Windows
    # paths containing ``\U``) cannot break the generated script.
    LAUNCH_KERNEL_PY = (f"import os\nos.chdir({root_directory + '/tmp'!r})\nfrom ipykernel import kernelapp as "
                        f"app\napp.launch_new_instance()")

    def __init__(self, name, description, **kwargs):
        super().__init__(name, description, **kwargs)
        self._sandbox_id = None

    @classmethod
    async def create(cls, config_data, **params):
        """Build an instance from a config mapping providing 'name' and 'description'."""
        instance = cls(name=config_data['name'], description=config_data['description'], **params)
        return instance

    @classmethod
    def kill_kernels(cls, sandbox_id):
        """Shut down the cached kernel for ``sandbox_id`` (if any) and clear its
        workspace and upload directories."""
        kc = AsyncPythonSandBoxTool._KERNEL_CLIENTS.pop(sandbox_id, None)
        if kc is not None:
            kc.shutdown()
        clear_files(os.path.join(WORK_DIR, sandbox_id))
        clear_files(os.path.join(FILE_DIR, sandbox_id))

    @staticmethod
    def _connection_file_ready(path: str) -> bool:
        """Return True once ``path`` exists and contains complete JSON."""
        if not os.path.isfile(path):
            return False
        try:
            with open(path, 'r') as fp:
                json.load(fp)
        except json.JSONDecodeError:
            # The kernel is still writing the file.
            return False
        return True

    def _start_kernel(self) -> BlockingKernelClient:
        """Launch a new IPython kernel process for this sandbox and connect to it.

        Raises:
            RuntimeError: if the kernel process exits, or does not write a valid
                connection file within ``KERNEL_START_TIMEOUT`` seconds.
        """
        sandbox_dir = os.path.join(WORK_DIR, self.sandbox_id)
        connection_file = os.path.join(sandbox_dir, f'kernel_connection_file_{self.sandbox_id}.json')
        launch_kernel_script = os.path.join(sandbox_dir, f'launch_kernel_{self.sandbox_id}.py')
        # Remove leftovers from an earlier kernel so we never connect to a stale one.
        for f in (connection_file, launch_kernel_script):
            if os.path.exists(f):
                os.remove(f)

        os.makedirs(sandbox_dir, exist_ok=True)
        with open(launch_kernel_script, 'w') as fout:
            fout.write(AsyncPythonSandBoxTool.LAUNCH_KERNEL_PY)

        kernel_process = subprocess.Popen([
            sys.executable,
            launch_kernel_script,
            '--IPKernelApp.connection_file',
            connection_file,
            '--matplotlib=inline',
            '--quiet',
        ],
            cwd=WORK_DIR)

        # Wait until the kernel has fully written its connection file, bailing out
        # if the process dies or the deadline passes instead of spinning forever.
        deadline = time.monotonic() + AsyncPythonSandBoxTool.KERNEL_START_TIMEOUT
        while not self._connection_file_ready(connection_file):
            if kernel_process.poll() is not None:
                raise RuntimeError(
                    f'Kernel process exited with code {kernel_process.returncode} '
                    f'before writing {connection_file}')
            if time.monotonic() > deadline:
                kernel_process.kill()
                kernel_process.wait()
                raise RuntimeError(
                    f'Kernel did not write {connection_file} within '
                    f'{AsyncPythonSandBoxTool.KERNEL_START_TIMEOUT} seconds')
            time.sleep(0.1)

        kc = BlockingKernelClient(connection_file=connection_file)
        kc.load_connection_file()
        kc.start_channels()
        kc.wait_for_ready()
        return kc

    async def set_sandbox_id(self, sandbox_id):
        self._sandbox_id = sandbox_id

    @property
    def sandbox_id(self):
        """Getter for sandbox_id."""
        return self._sandbox_id

    async def sync_to_sandbox(self, file: Union[str, Dict, FileStorage]) -> str:
        """Return the path of ``file`` inside this sandbox's upload directory.

        Only the last '/'-separated component of the name is used. A
        ``FileStorage`` is resolved through its ``filename``; any other value is
        treated as a path string.
        NOTE(review): ``Dict`` inputs are not handled (their schema is not known here).
        """
        if isinstance(file, FileStorage):
            file = file.filename or ''
        return os.path.join(root_directory, f"tmp/upload_files/{self.sandbox_id}/{file.split('/')[-1]}")

    @staticmethod
    def _input_handler(input_code: str) -> str:
        """Extract the Python source to run from ``input_code``.

        The contents of all fenced blocks (optionally tagged ``python``,
        ``python3`` or ``py``) are joined with newlines. If the input has no fenced
        block, the whole input is treated as code. The result is stripped.
        """
        code_blocks = re.findall(r'```(?:(?:python3?|py)\b)?\s*(.*?)\s*```', input_code, re.DOTALL)
        if not code_blocks:
            # Un-fenced input: run it as-is rather than silently running nothing.
            return input_code.strip()
        return '\n'.join(code_blocks).strip()

    @staticmethod
    def _escape_ansi(line: str) -> str:
        """Strip ANSI/VT100 escape sequences (e.g. colour codes) from ``line``."""
        ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
        return ansi_escape.sub('', line)

    @staticmethod
    def _execute_code(kc: BlockingKernelClient, code: str,
                      timeout: Union[float, None] = None) -> PythonSandBoxToolResponse:
        """Execute ``code`` on kernel client ``kc`` and collect its output.

        Only IOPub messages whose parent is this execute request are considered,
        so leftovers from earlier requests can neither end the wait early nor
        leak into the result. Collection stops when the kernel reports ``idle``
        for this request.

        Args:
            kc: a connected kernel client.
            code: the Python source to execute.
            timeout: maximum seconds to wait for each IOPub message; ``None``
                (the default) waits indefinitely.

        Returns:
            A response typed ERROR if the kernel reported an exception, SUCCESS
            otherwise (including runs that print nothing), or FAIL on timeout or
            an unexpected client-side error.
        """
        kc.wait_for_ready()
        msg_id = kc.execute(code)
        result = []
        saw_error = False
        state = None  # assigned exactly once, when collection ends

        while state is None:
            try:
                msg = kc.get_iopub_msg(timeout=timeout)
                if msg.get('parent_header', {}).get('msg_id') != msg_id:
                    continue  # belongs to some other request
                msg_type = msg['msg_type']
                logger.info(msg_type)
                content = msg['content']
                if msg_type == 'status':
                    if content.get('execution_state') == 'idle':
                        # An error is sticky: later output must not mask it.
                        state = _Type.ERROR if saw_error else _Type.SUCCESS
                elif msg_type == 'execute_result':
                    result.append(content['data'].get('text/plain', ''))
                elif msg_type == 'stream':
                    result.append(content['text'])
                elif msg_type == 'error':
                    result.append(AsyncPythonSandBoxTool._escape_ansi('\n'.join(content['traceback'])))
                    saw_error = True
            except queue.Empty:
                result.append('Timeout: Code execution exceeded the time limit.')
                state = _Type.FAIL
            except Exception:
                result.append('The code interpreter encountered an unexpected error.')
                logger.error(''.join(traceback.format_exception(*sys.exc_info())))
                state = _Type.FAIL
        return PythonSandBoxToolResponse(sand_box_response='\n'.join(result), _type=state)

    async def async_run(self, req: str):
        """Run the code found in ``req`` in this sandbox's kernel, starting one if needed.

        NOTE(review): kernel start-up and execution are blocking calls made directly
        on the event loop; consider offloading them if concurrency matters.
        """
        formatted_input = self._input_handler(req)
        kc = AsyncPythonSandBoxTool._KERNEL_CLIENTS.get(self.sandbox_id)
        if kc is None:
            kc = self._start_kernel()
            AsyncPythonSandBoxTool._KERNEL_CLIENTS[self.sandbox_id] = kc
        return self._execute_code(kc, formatted_input)
| |
|
| |
|
| |
|