| |
| import inspect |
| from typing import Dict, List, Optional, Any |
|
|
| from ..core.module import BaseModule |
|
|
# JSON Schema type names that a Tool input may declare as its "type".
# Every entry must also be a key of the json_to_python mapping in
# Tool.validate_attributes, which looks each accepted type up there.
ALLOWED_TYPES = [
    "string",
    "number",
    "integer",
    "boolean",
    "object",
    "array",
]
|
|
|
|
class Tool(BaseModule):
    """Base class for a callable tool described by a function-calling schema.

    Subclasses declare the tool through class attributes:

    * ``name`` (str): function name reported in the schema.
    * ``description`` (str): human-readable description of the tool.
    * ``inputs`` (dict): one entry per argument, shaped like
      ``{"input_name": {"type": "string", "description": "input description"}, ...}``.
      Each ``type`` must be one of ``ALLOWED_TYPES``, and each key must be a
      parameter of ``__call__`` annotated with the matching Python type.
    * ``required`` (list of str, optional): names from ``inputs`` that the
      caller must supply.

    These declarations are checked by ``validate_attributes`` as soon as the
    subclass statement executes, so a misconfigured tool fails at class
    creation rather than at call time.
    """

    name: str
    description: str
    inputs: Dict[str, Dict[str, Any]]
    required: Optional[List[str]] = None

    def __init_subclass__(cls, **kwargs):
        """Validate every subclass's tool declaration when it is defined.

        Keyword arguments from the class statement are forwarded to the next
        ``__init_subclass__`` in the MRO; swallowing them would make any
        class-level keyword raise ``TypeError``.
        """
        super().__init_subclass__(**kwargs)
        cls.validate_attributes()

    def get_tool_schema(self) -> Dict:
        """Return this tool's function-calling schema.

        The result has the ``{"type": "function", "function": {...}}`` shape
        used by OpenAI-style chat APIs: ``name`` and ``description`` come from
        the class attributes, ``parameters`` is a JSON-Schema object whose
        ``properties`` are ``inputs``.

        ``parameters["required"]`` is always a list. It is empty when the
        class leaves ``required`` as ``None``, because JSON Schema defines
        ``required`` as an array and ``null`` is not valid there.
        """
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": self.inputs,
                    # Fresh list: never null, and callers editing the schema
                    # cannot mutate the class-level ``required`` list.
                    "required": list(self.required) if self.required else [],
                },
            },
        }

    @classmethod
    def validate_attributes(cls):
        """Check that the class declares a well-formed tool definition.

        The following checks are made, in order:

        1. ``name`` and ``description`` exist and are ``str``, and ``inputs``
           exists and is a ``dict``.
        2. Every input entry is a dict with ``type`` and ``description``
           keys, and its ``type`` is in ``ALLOWED_TYPES``.
        3. Every input name is a parameter of ``__call__``, annotated with
           exactly the Python type mapped from its JSON type.
        4. Every name listed in ``required`` is a key of ``inputs``.

        Raises:
            ValueError: on the first violation found.
        """
        required_attributes = {
            "name": str,
            "description": str,
            "inputs": dict,
        }

        # JSON Schema type -> exact annotation expected on the __call__ parameter.
        json_to_python = {
            "string": str,
            "integer": int,
            "number": float,
            "boolean": bool,
            "object": dict,
            "array": list,
        }

        for attr, attr_type in required_attributes.items():
            if not hasattr(cls, attr):
                raise ValueError(f"Attribute {attr} is required")
            if not isinstance(getattr(cls, attr), attr_type):
                raise ValueError(f"Attribute {attr} must be of type {attr_type}")

        # __call__'s signature is the same for every input, so inspect it once.
        call_parameters = inspect.signature(cls.__call__).parameters

        for input_name, input_content in cls.inputs.items():
            if not isinstance(input_content, dict):
                raise ValueError(f"Input '{input_name}' must be a dictionary")
            if "type" not in input_content or "description" not in input_content:
                raise ValueError(f"Input '{input_name}' must have 'type' and 'description'")
            if input_content["type"] not in ALLOWED_TYPES:
                raise ValueError(f"Input '{input_name}' must have a valid type, should be one of {ALLOWED_TYPES}")

            if input_name not in call_parameters:
                raise ValueError(f"Input '{input_name}' is not found in __call__")
            # Exact comparison: the annotation must be the bare Python type
            # (e.g. ``str``), not a string or a typing wrapper.
            if call_parameters[input_name].annotation != json_to_python[input_content["type"]]:
                raise ValueError(f"Input '{input_name}' has a type mismatch in __call__")

        for required_input in cls.required or []:
            if required_input not in cls.inputs:
                raise ValueError(f"Required input '{required_input}' is not found in inputs")

    def __call__(self, **kwargs):
        """Run the tool. Subclasses must override this with explicit,
        annotated parameters that match ``inputs``."""
        raise NotImplementedError("All tools must implement __call__")
|
|
class Toolkit(BaseModule):
    """A named, ordered collection of ``Tool`` instances.

    Offers lookup by tool name and helpers that project each tool's name,
    description, inputs, or schema into a list that follows ``tools`` order.
    """

    name: str
    tools: List[Tool]

    def _collect(self, attribute: str) -> list:
        # Gather one attribute from every tool, preserving list order.
        return [getattr(member, attribute) for member in self.tools]

    def get_tool_names(self) -> List[str]:
        """Return the ``name`` of each tool."""
        return self._collect("name")

    def get_tool_descriptions(self) -> List[str]:
        """Return the ``description`` of each tool."""
        return self._collect("description")

    def get_tool_inputs(self) -> List[Dict]:
        """Return the ``inputs`` mapping of each tool."""
        return self._collect("inputs")

    def add_tool(self, tool: Tool):
        """Append ``tool`` to the end of the collection (in place)."""
        self.tools.append(tool)

    def remove_tool(self, tool_name: str):
        """Drop every tool whose name equals ``tool_name``.

        ``tools`` is rebound to a new list; a missing name is not an error.
        """
        survivors = [member for member in self.tools if member.name != tool_name]
        self.tools = survivors

    def get_tool(self, tool_name: str) -> Tool:
        """Return the first tool named ``tool_name``.

        Raises:
            ValueError: if no tool has that name.
        """
        for candidate in self.tools:
            if candidate.name == tool_name:
                break
        else:
            raise ValueError(f"Tool '{tool_name}' not found")
        return candidate

    def get_tools(self) -> List[Tool]:
        """Return the underlying tool list itself (not a copy)."""
        return self.tools

    def get_tool_schemas(self) -> List[Dict]:
        """Return each tool's ``get_tool_schema()`` result, in order."""
        return [member.get_tool_schema() for member in self.tools]
| |