| import json |
| import logging |
| import mimetypes |
| from collections.abc import Generator |
| from os import listdir, path |
| from threading import Lock, Thread |
| from typing import Any, Optional, Union |
|
|
| from configs import dify_config |
| from core.agent.entities import AgentToolEntity |
| from core.app.entities.app_invoke_entities import InvokeFrom |
| from core.helper.module_import_helper import load_single_subclass_from_source |
| from core.helper.position_helper import is_filtered |
| from core.model_runtime.utils.encoders import jsonable_encoder |
| from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral |
| from core.tools.entities.common_entities import I18nObject |
| from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter |
| from core.tools.errors import ToolProviderNotFoundError |
| from core.tools.provider.api_tool_provider import ApiToolProviderController |
| from core.tools.provider.builtin._positions import BuiltinToolProviderSort |
| from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController |
| from core.tools.tool.api_tool import ApiTool |
| from core.tools.tool.builtin_tool import BuiltinTool |
| from core.tools.tool.tool import Tool |
| from core.tools.tool_label_manager import ToolLabelManager |
| from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager |
| from extensions.ext_database import db |
| from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider |
| from services.tools.tools_transform_service import ToolTransformService |
|
|
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
class ToolManager:
    """
    Resolve tool providers and tool runtimes.

    Builtin providers are discovered from the ``provider/builtin`` directory next to
    this module and cached on the class. Api and workflow providers are read from the
    database per tenant.
    """

    # Serializes the one-time discovery of builtin providers.
    _builtin_provider_lock = Lock()
    # builtin provider name -> provider controller instance
    _builtin_providers: dict[str, BuiltinToolProviderController] = {}
    # True once _list_builtin_providers has walked every builtin provider directory
    _builtin_providers_loaded = False
    # builtin tool name -> tool label
    _builtin_tools_labels: dict[str, I18nObject] = {}

    @staticmethod
    def _builtin_providers_dir() -> str:
        """Return the absolute path of the directory that holds the builtin providers."""
        return path.join(path.dirname(path.realpath(__file__)), "provider", "builtin")

    @staticmethod
    def _default_api_icon() -> dict[str, str]:
        """
        Return a fresh icon dict used when an api provider's stored icon is unusable.

        The emoji is written as the single code point U+1F601. A surrogate-pair escape
        would build a str holding two lone surrogates, which cannot be encoded as UTF-8.
        """
        return {"background": "#252525", "content": "\U0001f601"}

    @staticmethod
    def _get_api_auth_type(credentials: Optional[dict[str, Any]]) -> ApiProviderAuthType:
        """
        Map stored api provider credentials to an auth type.

        :param credentials: the stored credentials dict (may be empty or None)
        :return: API_KEY when ``auth_type`` is "api_key", otherwise NONE
        """
        if (credentials or {}).get("auth_type") == "api_key":
            return ApiProviderAuthType.API_KEY
        return ApiProviderAuthType.NONE

    @staticmethod
    def _is_unset_value(value: Any) -> bool:
        """Treat None, "" and empty containers as unset, but keep 0 / 0.0 / False as real values."""
        return not value and value != 0

    @classmethod
    def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController:
        """
        get the builtin provider

        :param provider: the name of the provider
        :return: the provider
        :raises ToolProviderNotFoundError: if no builtin provider with that name was loaded
        """
        # The background preload thread fills the cache one provider at a time, so a
        # non-empty cache is not necessarily a complete one. Until loading has finished,
        # a miss triggers a load. The load blocks on the provider lock while another
        # thread is loading, so the lookup is not answered from a partial cache.
        if provider not in cls._builtin_providers and not cls._builtin_providers_loaded:
            cls.load_builtin_providers_cache()

        if provider not in cls._builtin_providers:
            raise ToolProviderNotFoundError(f"builtin provider {provider} not found")

        return cls._builtin_providers[provider]

    @classmethod
    def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool:
        """
        get the builtin tool

        :param provider: the name of the provider
        :param tool_name: the name of the tool
        :return: the tool returned by the provider controller's ``get_tool``
        """
        return cls.get_builtin_provider(provider).get_tool(tool_name)

    @classmethod
    def get_tool(
        cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: Optional[str] = None
    ) -> Union[BuiltinTool, ApiTool]:
        """
        get the tool (without a runtime attached)

        :param provider_type: the type of the provider ("builtin" or "api")
        :param provider_id: the builtin provider name, or the api provider id
        :param tool_name: the name of the tool
        :param tenant_id: the tenant id, required for api providers
        :return: the tool
        :raises ValueError: if tenant_id is missing for an api provider
        :raises NotImplementedError: for "app" providers
        :raises ToolProviderNotFoundError: for unknown provider types
        """
        if provider_type == "builtin":
            return cls.get_builtin_tool(provider_id, tool_name)
        elif provider_type == "api":
            if tenant_id is None:
                raise ValueError("tenant id is required for api provider")
            api_provider, _ = cls.get_api_provider_controller(tenant_id, provider_id)
            return api_provider.get_tool(tool_name)
        elif provider_type == "app":
            raise NotImplementedError("app provider not implemented")
        else:
            raise ToolProviderNotFoundError(f"provider type {provider_type} not found")

    @classmethod
    def get_tool_runtime(
        cls,
        provider_type: str,
        provider_id: str,
        tool_name: str,
        tenant_id: str,
        invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
        tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
    ) -> Union[BuiltinTool, ApiTool]:
        """
        get the tool with a runtime (tenant, decrypted credentials, invoke source) attached

        :param provider_type: "builtin", "api" or "workflow"
        :param provider_id: the builtin provider name, or the api / workflow provider id
        :param tool_name: the name of the tool
        :param tenant_id: the tenant id
        :param invoke_from: where the invocation originates
        :param tool_invoke_from: whether the tool is invoked by an agent or a workflow
        :return: the tool forked with its runtime
        :raises ToolProviderNotFoundError: if the provider (or its tenant record) is missing
        :raises NotImplementedError: for "app" providers
        """
        if provider_type == "builtin":
            provider_controller = cls.get_builtin_provider(provider_id)
            builtin_tool = provider_controller.get_tool(tool_name)

            # providers that need no credentials can run without a tenant db record
            if not provider_controller.need_credentials:
                return builtin_tool.fork_tool_runtime(
                    runtime={
                        "tenant_id": tenant_id,
                        "credentials": {},
                        "invoke_from": invoke_from,
                        "tool_invoke_from": tool_invoke_from,
                    }
                )

            # otherwise the tenant must have configured (and stored) credentials
            builtin_provider: Optional[BuiltinToolProvider] = (
                db.session.query(BuiltinToolProvider)
                .filter(
                    BuiltinToolProvider.tenant_id == tenant_id,
                    BuiltinToolProvider.provider == provider_id,
                )
                .first()
            )
            if builtin_provider is None:
                raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")

            tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
            decrypted_credentials = tool_configuration.decrypt_tool_credentials(builtin_provider.credentials)

            return builtin_tool.fork_tool_runtime(
                runtime={
                    "tenant_id": tenant_id,
                    "credentials": decrypted_credentials,
                    "runtime_parameters": {},
                    "invoke_from": invoke_from,
                    "tool_invoke_from": tool_invoke_from,
                }
            )

        elif provider_type == "api":
            if tenant_id is None:
                raise ValueError("tenant id is required for api provider")

            api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)

            tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
            decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)

            return api_provider.get_tool(tool_name).fork_tool_runtime(
                runtime={
                    "tenant_id": tenant_id,
                    "credentials": decrypted_credentials,
                    "invoke_from": invoke_from,
                    "tool_invoke_from": tool_invoke_from,
                }
            )
        elif provider_type == "workflow":
            workflow_provider = (
                db.session.query(WorkflowToolProvider)
                .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
                .first()
            )
            if workflow_provider is None:
                raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

            controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)

            # NOTE(review): assumes a workflow provider always exposes at least one tool
            return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
                runtime={
                    "tenant_id": tenant_id,
                    "credentials": {},
                    "invoke_from": invoke_from,
                    "tool_invoke_from": tool_invoke_from,
                }
            )
        elif provider_type == "app":
            raise NotImplementedError("app provider not implemented")
        else:
            raise ToolProviderNotFoundError(f"provider type {provider_type} not found")

    @classmethod
    def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict):
        """
        Resolve and cast one runtime parameter value.

        :param parameter_rule: the parameter declaration
        :param parameters: the configured values, keyed by parameter name
        :return: the value cast by ``parameter_rule.type.cast_value``
        :raises ValueError: if a required value is unset and has no usable default,
            or a SELECT value is not one of the declared options
        """
        parameter_value = parameters.get(parameter_rule.name)
        if cls._is_unset_value(parameter_value):
            # fall back to the declared default; 0 / False defaults count as real values
            parameter_value = parameter_rule.default
            if cls._is_unset_value(parameter_value) and parameter_rule.required:
                raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")

        if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
            # the configured value must be one of the declared options
            options = [x.value for x in parameter_rule.options]
            if parameter_value is not None and parameter_value not in options:
                raise ValueError(
                    f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
                )

        return parameter_rule.type.cast_value(parameter_value)

    @classmethod
    def get_agent_tool_runtime(
        cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER
    ) -> Tool:
        """
        get the agent tool runtime, with its form parameters resolved and decrypted

        :raises ValueError: if the tool has a required file-typed parameter (unsupported in agents)
        """
        tool_entity = cls.get_tool_runtime(
            provider_type=agent_tool.provider_type,
            provider_id=agent_tool.provider_id,
            tool_name=agent_tool.tool_name,
            tenant_id=tenant_id,
            invoke_from=invoke_from,
            tool_invoke_from=ToolInvokeFrom.AGENT,
        )
        runtime_parameters = {}
        parameters = tool_entity.get_all_runtime_parameters()
        for parameter in parameters:
            # agents cannot supply file parameters, so a required one makes the tool unusable
            if (
                parameter.type
                in {
                    ToolParameter.ToolParameterType.SYSTEM_FILES,
                    ToolParameter.ToolParameterType.FILE,
                    ToolParameter.ToolParameterType.FILES,
                }
                and parameter.required
            ):
                raise ValueError(f"file type parameter {parameter.name} not supported in agent")

            # only FORM parameters are configured ahead of time
            if parameter.form == ToolParameter.ToolParameterForm.FORM:
                value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
                runtime_parameters[parameter.name] = value

        # decrypt the configured parameter values
        encryption_manager = ToolParameterConfigurationManager(
            tenant_id=tenant_id,
            tool_runtime=tool_entity,
            provider_name=agent_tool.provider_id,
            provider_type=agent_tool.provider_type,
            identity_id=f"AGENT.{app_id}",
        )
        runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)

        tool_entity.runtime.runtime_parameters.update(runtime_parameters)
        return tool_entity

    @classmethod
    def get_workflow_tool_runtime(
        cls,
        tenant_id: str,
        app_id: str,
        node_id: str,
        workflow_tool: "ToolEntity",
        invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
    ) -> Tool:
        """
        get the workflow tool runtime, with its form parameters resolved and decrypted

        NOTE(review): ``ToolEntity`` is only a string annotation here; it is not imported
        in this module's visible imports.
        """
        tool_entity = cls.get_tool_runtime(
            provider_type=workflow_tool.provider_type,
            provider_id=workflow_tool.provider_id,
            tool_name=workflow_tool.tool_name,
            tenant_id=tenant_id,
            invoke_from=invoke_from,
            tool_invoke_from=ToolInvokeFrom.WORKFLOW,
        )
        runtime_parameters = {}
        parameters = tool_entity.get_all_runtime_parameters()

        for parameter in parameters:
            # only FORM parameters are configured on the workflow node
            if parameter.form == ToolParameter.ToolParameterForm.FORM:
                value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
                runtime_parameters[parameter.name] = value

        encryption_manager = ToolParameterConfigurationManager(
            tenant_id=tenant_id,
            tool_runtime=tool_entity,
            provider_name=workflow_tool.provider_id,
            provider_type=workflow_tool.provider_type,
            identity_id=f"WORKFLOW.{app_id}.{node_id}",
        )

        if runtime_parameters:
            runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)

        tool_entity.runtime.runtime_parameters.update(runtime_parameters)
        return tool_entity

    @classmethod
    def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]:
        """
        get the absolute path of the icon of the builtin provider

        :param provider: the name of the provider
        :return: the absolute path of the icon, the mime type of the icon
        :raises ToolProviderNotFoundError: if the provider or its icon file is missing
        """
        provider_controller = cls.get_builtin_provider(provider)

        absolute_path = path.join(
            cls._builtin_providers_dir(),
            provider,
            "_assets",
            provider_controller.identity.icon,
        )
        if not path.exists(absolute_path):
            raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found")

        # unknown extensions are served as generic binary
        mime_type, _ = mimetypes.guess_type(absolute_path)
        mime_type = mime_type or "application/octet-stream"

        return absolute_path, mime_type

    @classmethod
    def list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
        """
        Yield every builtin provider, loading them on first use.

        When loading is needed, the provider lock is held for the whole iteration.
        """
        # fast path: serve a snapshot of the cache without taking the lock
        if cls._builtin_providers_loaded:
            yield from list(cls._builtin_providers.values())
            return

        with cls._builtin_provider_lock:
            # re-check under the lock: another thread may have finished loading while we waited
            if cls._builtin_providers_loaded:
                yield from list(cls._builtin_providers.values())
                return

            yield from cls._list_builtin_providers()

    @classmethod
    def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
        """
        Discover, instantiate and cache every builtin provider, yielding each one.

        A provider that fails to load is logged and skipped. The loaded flag is set only
        after the whole directory has been walked.
        """
        builtin_dir = cls._builtin_providers_dir()
        for provider_name in listdir(builtin_dir):
            # skip dunder entries such as __init__.py / __pycache__, and non-directories
            if provider_name.startswith("__"):
                continue
            if not path.isdir(path.join(builtin_dir, provider_name)):
                continue

            try:
                provider_class = load_single_subclass_from_source(
                    module_name=f"core.tools.provider.builtin.{provider_name}.{provider_name}",
                    script_path=path.join(builtin_dir, provider_name, f"{provider_name}.py"),
                    parent_type=BuiltinToolProviderController,
                )
                provider: BuiltinToolProviderController = provider_class()
                cls._builtin_providers[provider.identity.name] = provider
                for tool in provider.get_tools():
                    cls._builtin_tools_labels[tool.identity.name] = tool.identity.label
            except Exception as e:
                logger.exception("load builtin provider %s error: %s", provider_name, e)
                continue

            # yield outside the try so errors raised at the consumer's side are not
            # misreported as provider load failures
            yield provider

        cls._builtin_providers_loaded = True

    @classmethod
    def load_builtin_providers_cache(cls):
        """Populate the builtin provider cache by draining list_builtin_providers."""
        for _ in cls.list_builtin_providers():
            pass

    @classmethod
    def clear_builtin_providers_cache(cls):
        """Drop all cached builtin providers and tool labels so the next access reloads them."""
        cls._builtin_providers = {}
        cls._builtin_tools_labels = {}
        cls._builtin_providers_loaded = False

    @classmethod
    def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
        """
        get the tool label

        :param tool_name: the name of the tool
        :return: the label of the tool, or None if no builtin tool has that name
        """
        # the label cache is filled alongside the provider cache, so it can also be
        # partial while loading is in progress
        if tool_name not in cls._builtin_tools_labels and not cls._builtin_providers_loaded:
            cls.load_builtin_providers_cache()

        return cls._builtin_tools_labels.get(tool_name)

    @classmethod
    def user_list_providers(
        cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral
    ) -> list[UserToolProvider]:
        """
        List the tool providers visible to a tenant, sorted for display.

        :param user_id: the requesting user (currently unused here)
        :param tenant_id: the tenant id
        :param typ: restrict to one provider type; falsy means builtin + api + workflow
        :return: the providers, ordered by BuiltinToolProviderSort
        """
        result_providers: dict[str, UserToolProvider] = {}

        filters = ["builtin", "api", "workflow"] if not typ else [typ]

        if "builtin" in filters:
            builtin_providers = cls.list_builtin_providers()

            # index the tenant's stored builtin provider records by provider name
            # (setdefault keeps the first record per name)
            db_builtin_providers: dict[str, BuiltinToolProvider] = {}
            for db_builtin_provider in (
                db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
            ):
                db_builtin_providers.setdefault(db_builtin_provider.provider, db_builtin_provider)

            for provider in builtin_providers:
                # honour the configured include / exclude lists
                if is_filtered(
                    include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
                    exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
                    data=provider,
                    name_func=lambda x: x.identity.name,
                ):
                    continue

                user_provider = ToolTransformService.builtin_provider_to_user_provider(
                    provider_controller=provider,
                    db_provider=db_builtin_providers.get(provider.identity.name),
                    decrypt_credentials=False,
                )
                result_providers[provider.identity.name] = user_provider

        if "api" in filters:
            db_api_providers: list[ApiToolProvider] = (
                db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
            )

            api_provider_controllers = [
                {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
                for provider in db_api_providers
            ]

            # fetch labels for all api providers in one call
            labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])

            for api_provider_controller in api_provider_controllers:
                user_provider = ToolTransformService.api_provider_to_user_provider(
                    provider_controller=api_provider_controller["controller"],
                    db_provider=api_provider_controller["provider"],
                    decrypt_credentials=False,
                    labels=labels.get(api_provider_controller["controller"].provider_id, []),
                )
                result_providers[f"api_provider.{user_provider.name}"] = user_provider

        if "workflow" in filters:
            workflow_providers: list[WorkflowToolProvider] = (
                db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
            )

            workflow_provider_controllers = []
            for provider in workflow_providers:
                try:
                    workflow_provider_controllers.append(
                        ToolTransformService.workflow_provider_to_controller(db_provider=provider)
                    )
                except Exception as e:
                    # best effort: skip providers that cannot be turned into a controller
                    # (presumably their app/workflow is gone) instead of failing the listing
                    logger.warning("skip workflow tool provider %s: %s", provider.id, e)

            labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)

            for provider_controller in workflow_provider_controllers:
                user_provider = ToolTransformService.workflow_provider_to_user_provider(
                    provider_controller=provider_controller,
                    labels=labels.get(provider_controller.provider_id, []),
                )
                result_providers[f"workflow_provider.{user_provider.name}"] = user_provider

        return BuiltinToolProviderSort.sort(list(result_providers.values()))

    @classmethod
    def get_api_provider_controller(
        cls, tenant_id: str, provider_id: str
    ) -> tuple[ApiToolProviderController, dict[str, Any]]:
        """
        get the api provider

        :param tenant_id: the tenant id
        :param provider_id: the id of the api provider
        :return: the provider controller, the stored (still encrypted) credentials
        :raises ToolProviderNotFoundError: if the tenant has no such api provider
        """
        provider: Optional[ApiToolProvider] = (
            db.session.query(ApiToolProvider)
            .filter(
                ApiToolProvider.id == provider_id,
                ApiToolProvider.tenant_id == tenant_id,
            )
            .first()
        )
        if provider is None:
            raise ToolProviderNotFoundError(f"api provider {provider_id} not found")

        credentials = provider.credentials
        controller = ApiToolProviderController.from_db(provider, cls._get_api_auth_type(credentials))
        controller.load_bundled_tools(provider.tools)

        return controller, credentials

    @classmethod
    def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
        """
        get an api provider by name, for display, with masked credentials

        :param provider: the name of the api provider
        :param tenant_id: the tenant id
        :return: a JSON-encodable dict describing the provider
        :raises ValueError: if the tenant has not added a provider with that name
        """
        # keep `provider` (the requested name) intact for the error message below
        db_provider: Optional[ApiToolProvider] = (
            db.session.query(ApiToolProvider)
            .filter(
                ApiToolProvider.tenant_id == tenant_id,
                ApiToolProvider.name == provider,
            )
            .first()
        )
        if db_provider is None:
            raise ValueError(f"you have not added provider {provider}")

        # unreadable / missing credentials are treated as empty
        try:
            credentials = json.loads(db_provider.credentials_str) or {}
        except (TypeError, ValueError):
            credentials = {}

        controller = ApiToolProviderController.from_db(db_provider, cls._get_api_auth_type(credentials))
        tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)

        decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
        masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)

        try:
            icon = json.loads(db_provider.icon)
        except (TypeError, ValueError):
            icon = cls._default_api_icon()

        labels = ToolLabelManager.get_tool_labels(controller)

        return jsonable_encoder(
            {
                "schema_type": db_provider.schema_type,
                "schema": db_provider.schema,
                "tools": db_provider.tools,
                "icon": icon,
                "description": db_provider.description,
                "credentials": masked_credentials,
                "privacy_policy": db_provider.privacy_policy,
                "custom_disclaimer": db_provider.custom_disclaimer,
                "labels": labels,
            }
        )

    @classmethod
    def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]:
        """
        get the tool icon

        :param tenant_id: the id of the tenant
        :param provider_type: the type of the provider
        :param provider_id: the id of the provider
        :return: an icon URL for builtin providers, otherwise the decoded icon dict
        :raises ToolProviderNotFoundError: if a workflow provider is missing
        :raises ValueError: for unknown provider types
        """
        if provider_type == "builtin":
            return (
                dify_config.CONSOLE_API_URL
                + "/console/api/workspaces/current/tool-provider/builtin/"
                + provider_id
                + "/icon"
            )
        elif provider_type == "api":
            # best effort: any failure (missing row, malformed icon json, ...) falls back
            # to the default icon rather than failing the request
            try:
                api_provider: Optional[ApiToolProvider] = (
                    db.session.query(ApiToolProvider)
                    .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
                    .first()
                )
                return json.loads(api_provider.icon)
            except Exception:
                return cls._default_api_icon()
        elif provider_type == "workflow":
            workflow_provider: Optional[WorkflowToolProvider] = (
                db.session.query(WorkflowToolProvider)
                .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
                .first()
            )
            if workflow_provider is None:
                raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

            return json.loads(workflow_provider.icon)
        else:
            raise ValueError(f"provider type {provider_type} not found")
|
|
|
|
| |
# Start filling the builtin provider cache on a daemon thread as soon as this module
# is imported; lookups made before it finishes fall back to loading via the lock.
Thread(
    name="pre_load_builtin_providers_cache",
    target=ToolManager.load_builtin_providers_cache,
    daemon=True,
).start()
|
|