| | import json |
| | from os import getenv |
| | from typing import Any |
| | from urllib.parse import urlencode |
| |
|
| | import httpx |
| |
|
| | from core.helper import ssrf_proxy |
| | from core.tools.entities.tool_bundle import ApiToolBundle |
| | from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType |
| | from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError |
| | from core.tools.tool.tool import Tool |
| |
|
# Default (connect, read) timeout pair passed as ``timeout`` to every API tool
# request (presumably seconds -- confirm against the HTTP client). Each entry
# can be overridden through the environment variable of the same name.
API_TOOL_DEFAULT_TIMEOUT = tuple(
    int(getenv(env_name, fallback))
    for env_name, fallback in (
        ("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10"),
        ("API_TOOL_DEFAULT_READ_TIMEOUT", "60"),
    )
)
| |
|
| |
|
class ApiTool(Tool):
    """
    Api tool.

    Builds an HTTP request from ``api_bundle`` (server URL, method and the
    OpenAPI operation description) plus the caller-supplied parameters, sends
    it through ``ssrf_proxy`` and returns the response as text.
    """

    api_bundle: ApiToolBundle

    def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
        """
        Fork a new tool with meta data.

        :param runtime: the meta data of a tool call processing (keyword arguments
            for ``Tool.Runtime``); the original note states tenant_id is required
        :return: a new tool of the same class carrying copies of this tool's
            identity, parameters, description and api bundle
        """
        return self.__class__(
            identity=self.identity.model_copy() if self.identity else None,
            parameters=self.parameters.copy() if self.parameters else None,
            description=self.description.model_copy() if self.description else None,
            api_bundle=self.api_bundle.model_copy() if self.api_bundle else None,
            runtime=Tool.Runtime(**runtime),
        )

    def validate_credentials(
        self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
    ) -> str:
        """
        Validate the credentials for the Api tool.

        The request headers are always assembled first, which checks the format
        of the runtime credentials. Unless ``format_only`` is set, the configured
        endpoint is then called and its response validated.

        :param credentials: not read by this method; the credentials actually
            used are ``self.runtime.credentials``
        :param parameters: parameters for the test call
        :param format_only: when True, skip the HTTP call and return ``""``
        :return: the parsed response text, or ``""`` when ``format_only`` is True
        """
        headers = self.assembling_request(parameters)

        if format_only:
            return ""

        response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
        return self.validate_and_parse_response(response)

    def tool_provider_type(self) -> ToolProviderType:
        """Return the provider type of this tool, always ``ToolProviderType.API``."""
        return ToolProviderType.API

    def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
        """
        Build the authentication headers and check required parameters.

        :param parameters: the caller-supplied tool parameters
        :return: the headers derived from the runtime credentials
        :raises ToolProviderCredentialValidationError: if ``auth_type`` is missing,
            or ``api_key_value`` is missing / not a string for ``api_key`` auth
        :raises ToolParameterValidationError: if a required parameter is absent
        """
        headers = {}
        credentials = self.runtime.credentials or {}

        if "auth_type" not in credentials:
            raise ToolProviderCredentialValidationError("Missing auth_type")

        if credentials["auth_type"] == "api_key":
            api_key_header = credentials.get("api_key_header", "api_key")

            if "api_key_value" not in credentials:
                raise ToolProviderCredentialValidationError("Missing api_key_value")
            api_key_value = credentials["api_key_value"]
            if not isinstance(api_key_value, str):
                raise ToolProviderCredentialValidationError("api_key_value must be a string")

            # Apply the prefix to a local copy only. Writing it back into the
            # runtime credentials would stack the prefix ("Bearer Bearer ...")
            # every time this method runs on the same tool instance.
            prefix = credentials.get("api_key_header_prefix")
            if api_key_value:
                if prefix == "basic":
                    api_key_value = f"Basic {api_key_value}"
                elif prefix == "bearer":
                    api_key_value = f"Bearer {api_key_value}"
                # "custom" (and any other prefix) sends the value unchanged.

            headers[api_key_header] = api_key_value

        # Defaults are not applied here: a missing required parameter is always
        # an error, even when the parameter declares a default.
        for parameter in self.api_bundle.parameters or []:
            if parameter.required and parameter.name not in parameters:
                raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")

        return headers

    def validate_and_parse_response(self, response: httpx.Response) -> str:
        """
        Validate the response and turn it into text.

        :param response: the HTTP response to inspect
        :return: the JSON body re-serialised as a string when it decodes as JSON,
            the raw response text otherwise, or a fixed hint for an empty body
        :raises ValueError: if ``response`` is not an ``httpx.Response``
        :raises ToolInvokeError: if the status code is 400 or above
        """
        if not isinstance(response, httpx.Response):
            raise ValueError(f"Invalid response type {type(response)}")

        if response.status_code >= 400:
            raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
        if not response.content:
            return "Empty response from the tool, please check your parameters and try again."

        # The decoded payload gets its own name so the fallback below still
        # refers to the real response object.
        try:
            payload = response.json()
            try:
                return json.dumps(payload, ensure_ascii=False)
            except Exception:
                return json.dumps(payload)
        except Exception:
            # Best effort: anything that is not (serialisable) JSON is returned as text.
            return response.text

    @staticmethod
    def get_parameter_value(parameter, parameters):
        """
        Resolve the value of one OpenAPI parameter.

        :param parameter: the OpenAPI parameter object (a mapping with at least ``name``)
        :param parameters: the caller-supplied tool parameters
        :return: the supplied value, else the schema default, else ``""``
        :raises ToolParameterValidationError: if the parameter is required but absent
        """
        name = parameter["name"]
        if name in parameters:
            return parameters[name]
        if parameter.get("required", False):
            raise ToolParameterValidationError(f"Missing required parameter {name}")
        return (parameter.get("schema", {}) or {}).get("default", "")

    def do_http_request(
        self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
    ) -> httpx.Response:
        """
        Do the http request depending on the api bundle.

        :param url: target URL; ``{name}`` placeholders are filled from path parameters
        :param method: HTTP method name, matched case-insensitively
        :param headers: request headers; updated in place with header parameters
            and the request body's ``Content-Type``
        :param parameters: the caller-supplied tool parameters
        :return: the response returned by ``ssrf_proxy``
        :raises ToolParameterValidationError: if a required parameter or body
            property is missing
        :raises ValueError: if ``method`` is not a supported HTTP method
        """
        method = method.lower()

        params = {}
        path_params = {}
        body = {}
        cookies = {}

        # Sort every declared OpenAPI parameter into the part of the request it belongs to.
        for parameter in self.api_bundle.openapi.get("parameters", []):
            value = self.get_parameter_value(parameter, parameters)
            location = parameter["in"]
            if location == "path":
                path_params[parameter["name"]] = value
            elif location == "query":
                # Empty strings (the "no value" marker) are left out of the query string.
                if value != "":
                    params[parameter["name"]] = value
            elif location == "cookie":
                cookies[parameter["name"]] = value
            elif location == "header":
                headers[parameter["name"]] = value

        # Build the request body from the first declared content type only.
        request_body = self.api_bundle.openapi.get("requestBody")
        if request_body is not None and "content" in request_body:
            content = request_body["content"]
            first_content_type = next(iter(content), None)
            if first_content_type is not None:
                headers["Content-Type"] = first_content_type
                # A media type without a schema is treated as an empty schema.
                body_schema = (content[first_content_type] or {}).get("schema") or {}
                required = body_schema.get("required", [])
                properties = body_schema.get("properties", {})
                for name, prop_schema in properties.items():
                    if name in parameters:
                        # Convert the supplied value to the declared type.
                        body[name] = self._convert_body_property_type(prop_schema, parameters[name])
                    elif name in required:
                        raise ToolParameterValidationError(
                            f"Missing required parameter {name} in operation {self.api_bundle.operation_id}"
                        )
                    elif "default" in prop_schema:
                        body[name] = prop_schema["default"]
                    else:
                        body[name] = None

        # Fill the path placeholders.
        for name, value in path_params.items():
            url = url.replace(f"{{{name}}}", f"{value}")

        # Serialise the body for the content types handled here; any other
        # content type sends the dict as-is.
        declared_content_type = headers.get("Content-Type")
        if declared_content_type == "application/json":
            body = json.dumps(body)
        elif declared_content_type == "application/x-www-form-urlencoded":
            body = urlencode(body)

        if method not in {"get", "head", "post", "put", "delete", "patch"}:
            # Report the offending method itself; this class has no ``method`` attribute.
            raise ValueError(f"Invalid http method {method}")

        return getattr(ssrf_proxy, method)(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            data=body,
            timeout=API_TOOL_DEFAULT_TIMEOUT,
            follow_redirects=True,
        )

    def _convert_body_property_any_of(
        self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
    ) -> Any:
        """
        Convert ``value`` using the first ``anyOf`` option that accepts it.

        :param property: the property schema owning the ``anyOf`` (only passed on
            to nested calls)
        :param value: the caller-supplied value
        :param any_of: the candidate schemas, tried in order
        :param max_recursive: remaining depth for nested ``anyOf`` lists
        :return: the converted value, or ``value`` unchanged if no option fits
        :raises Exception: if the nesting depth budget is exhausted
        """
        if max_recursive <= 0:
            raise Exception("Max recursion depth reached")

        options = any_of or []

        # An explicit None matches a "null" option directly; otherwise an
        # earlier "string" option would turn it into the text "None".
        if value is None and any(
            isinstance(option, dict) and option.get("type") == "null" for option in options
        ):
            return None

        for option in options:
            try:
                if "type" in option:
                    option_type = option["type"]
                    if option_type in ("integer", "int"):
                        return int(value)
                    elif option_type == "number":
                        return float(value) if "." in str(value) else int(value)
                    elif option_type == "string":
                        return str(value)
                    elif option_type == "boolean":
                        lowered = str(value).lower()
                        if lowered in {"true", "1"}:
                            return True
                        elif lowered in {"false", "0"}:
                            return False
                        else:
                            continue
                    elif option_type == "null" and not value:
                        return None
                    else:
                        continue
                elif "anyOf" in option and isinstance(option["anyOf"], list):
                    # Nested anyOf: recurse with a reduced depth budget.
                    return self._convert_body_property_any_of(property, value, option["anyOf"], max_recursive - 1)
            except (ValueError, TypeError):
                # TypeError covers inputs such as int(None); try the next option.
                continue

        # No option matched: hand the value back untouched.
        return value

    def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> Any:
        """
        Convert ``value`` to the type declared by a request body property schema.

        :param property: the property schema (``type`` or ``anyOf`` is consulted)
        :param value: the caller-supplied value
        :return: the converted value; ``value`` itself when it cannot be converted
            or when the schema declares neither ``type`` nor ``anyOf``
        """
        try:
            if "type" in property:
                declared_type = property["type"]
                if declared_type in ("integer", "int"):
                    return int(value)
                elif declared_type == "number":
                    # Keep integers integral; only dotted values become floats.
                    return float(value) if "." in str(value) else int(value)
                elif declared_type == "string":
                    return str(value)
                elif declared_type == "boolean":
                    # Parse textual booleans the same way the anyOf converter
                    # does; plain bool() would turn "false" and "0" into True.
                    if isinstance(value, str):
                        lowered = value.lower()
                        if lowered in {"true", "1"}:
                            return True
                        if lowered in {"false", "0"}:
                            return False
                    return bool(value)
                elif declared_type == "null":
                    return None
                elif declared_type in ("object", "array"):
                    if not isinstance(value, str):
                        return value
                    # Try strict JSON first so apostrophes inside valid JSON
                    # strings survive.
                    try:
                        return json.loads(value)
                    except ValueError:
                        pass
                    # Fall back to single-quoted text by swapping the quotes;
                    # on failure return the untouched input.
                    try:
                        return json.loads(value.replace("'", '"'))
                    except ValueError:
                        return value
                else:
                    raise ValueError(f"Invalid type {property['type']} for property {property}")
            elif "anyOf" in property and isinstance(property["anyOf"], list):
                return self._convert_body_property_any_of(property, value, property["anyOf"])
        except (ValueError, TypeError):
            # Conversion failed (including TypeError from e.g. int(None)):
            # pass the value through unchanged.
            return value

        # No usable type information: keep the caller's value instead of dropping it.
        return value

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        """
        Invoke the http request.

        :param user_id: not read by this method
        :param tool_parameters: the caller-supplied tool parameters
        :return: a text message carrying the parsed response
        """
        # Build auth headers and check required parameters.
        headers = self.assembling_request(tool_parameters)

        # Send the request.
        response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters)

        # Validate and parse the response, then wrap it as a text message.
        response_text = self.validate_and_parse_response(response)
        return self.create_text_message(response_text)
| |
|