| """ |
| Custom Shell Toolkit with Base Directory Support |
| |
| This toolkit provides shell command execution constrained to a specific base directory, |
| preventing agents from navigating outside their assigned working directory. |
| """ |
|
|
| import os |
| import subprocess |
| from pathlib import Path |
| from typing import List, Optional |
| from agno.tools import Toolkit |
| from agno.utils.log import logger |
|
|
|
|
class RestrictedShellTools(Toolkit):
    """
    Shell toolkit whose commands always start in a fixed base directory.

    Every command is launched with the base directory as its working
    directory, so relative paths used by an agent land in that directory.

    NOTE: this is a starting-directory guarantee, not a sandbox. Commands are
    run through the shell, so a command that itself uses ``cd`` or absolute
    paths can still reach locations outside the base directory.

    Only ``run_shell_command`` is registered as a tool with the parent
    ``Toolkit``; the remaining methods are plain helpers.
    """

    def __init__(self, base_dir: Optional[Path] = None, **kwargs):
        """
        Initialize the restricted shell toolkit.

        Args:
            base_dir: Directory in which all shell commands are started. It is
                passed through ``Path()``; a falsy value (``None`` or an empty
                string) selects the current working directory. The directory
                is created if it does not exist.
            **kwargs: Additional arguments passed to the parent Toolkit.
        """
        # Anchor the directory to an absolute path once, so a later change of
        # the process working directory cannot silently relocate it.
        chosen_dir = Path(base_dir) if base_dir else Path.cwd()
        self.base_dir = chosen_dir.absolute()

        self.base_dir.mkdir(parents=True, exist_ok=True)

        # base_dir is set before the parent initializer runs because the
        # bound method registered below belongs to this instance.
        super().__init__(
            name="restricted_shell_tools",
            tools=[self.run_shell_command],
            **kwargs
        )

        logger.info(f"RestrictedShellTools initialized with base_dir: {self.base_dir}")

    def run_shell_command(self, command: str, timeout: int = 30) -> str:
        """
        Run a shell command with the base directory as its working directory.

        The working directory is set only for the child process (via ``cwd``);
        the working directory of the calling process is never modified.

        Args:
            command (str): The shell command to execute (interpreted by the shell).
            timeout (int): Maximum execution time in seconds.

        Returns:
            str: The command's stripped standard output on success; otherwise
            a message describing the non-zero exit status, the timeout, or the
            error that prevented execution. Failures are reported through the
            returned string rather than raised.
        """
        logger.info(f"Executing shell command in {self.base_dir}: {command}")

        try:
            # NOTE: shell=True is intentional (the tool's contract is "run a
            # shell command"), which also means the command text is trusted
            # as-is; cwd only fixes where the command starts.
            result = subprocess.run(
                command,
                shell=True,
                capture_output=True,
                text=True,
                timeout=timeout,
                cwd=str(self.base_dir),
            )
        except subprocess.TimeoutExpired:
            error_msg = f"Command timed out after {timeout} seconds: {command}"
            logger.error(error_msg)
            return error_msg
        except Exception as e:
            # Boundary handler: the caller expects a string describing the
            # failure (e.g. missing directory, decoding error), not a raise.
            error_msg = f"Error executing command '{command}': {str(e)}"
            logger.error(error_msg)
            return error_msg

        logger.debug(f"Command executed with return code: {result.returncode}")

        if result.returncode != 0:
            error_msg = f"Command failed with return code {result.returncode}\nSTDERR: {result.stderr}\nSTDOUT: {result.stdout}"
            logger.warning(error_msg)
            return error_msg

        output = result.stdout.strip()
        # Only the first 200 characters are logged to keep debug logs short.
        logger.debug(f"Command output: {output[:200]}{'...' if len(output) > 200 else ''}")
        return output

    def get_current_directory(self) -> str:
        """
        Return the base directory path.

        Returns:
            str: Absolute path of the base directory.
        """
        return str(self.base_dir.absolute())

    def list_directory_contents(self) -> str:
        """
        List the contents of the base directory using ``ls -la``.

        Returns:
            str: Directory listing, or an error message from
            ``run_shell_command`` if the command fails.
        """
        return self.run_shell_command("ls -la")

    def check_file_exists(self, filename: str) -> str:
        """
        Check whether a path exists inside the base directory.

        Paths that point outside the base directory (through ``..`` segments
        or by being absolute) are reported as not existing in it, since the
        returned message is specifically about the base directory.

        Args:
            filename (str): Path of the file to check, relative to the base
                directory.

        Returns:
            str: Message stating whether the file exists in the base directory.
        """
        # Normalize lexically (collapsing ".." segments) so containment can be
        # decided without following symlinks.
        base = Path(os.path.normpath(self.base_dir))
        candidate = Path(os.path.normpath(self.base_dir / filename))
        is_inside = candidate == base or base in candidate.parents

        if is_inside and candidate.exists():
            return f"File '{filename}' exists in {self.base_dir}"
        return f"File '{filename}' does not exist in {self.base_dir}"